In [85]:
# Core libraries
import pandas as pd
import numpy as np
import joblib
from collections import Counter, defaultdict

# Sklearn imports (metrics consolidated into a single import; previously
# sklearn.metrics was imported three separate times and torch twice)
from sklearn.model_selection import StratifiedKFold, train_test_split, GridSearchCV
from sklearn.metrics import (
    auc, classification_report, roc_auc_score, confusion_matrix,
    precision_recall_curve, matthews_corrcoef, ConfusionMatrixDisplay, roc_curve
)
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from sklearn.utils import class_weight
from sklearn.utils.class_weight import compute_class_weight
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import plot_tree

# TensorFlow/Keras imports
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Input, Dense, Embedding, Flatten, Concatenate, BatchNormalization, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.metrics import Precision, Recall, AUC
from tensorflow.keras import backend as K

# PyTorch (optimizer/scheduler used by TabNet)
import torch

# TabNet
from pytorch_tabnet.tab_model import TabNetClassifier

# Imbalanced Learning
from imblearn.over_sampling import SMOTENC

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Warnings
from warnings import filterwarnings
filterwarnings('ignore')

This notebook contains two types of models that try to predict the following:

  • Failure Event
  • Maintenance Needs

For each objective, a model is built using Random Forest (an ensemble ML model) and TabNet (a neural-network-based model); therefore there are 4 models in this notebook. For a detailed analysis of the models, read the report. To avoid redundant information, I have only focused here on comparing the models' best accuracy in K-fold. I have focused on explaining the models here and on the evaluations in the report.

Model explanations have been done only for the Failure Event objective, whereas for Maintenance Needs only the model and results are displayed, as the same model is replicated with different dataset labels passed in.

Failure Event¶

Data Loading and Visualization¶

In [86]:
# Load the raw maintenance dataset and preview the first rows
df = pd.read_csv('military_asset_maintenance_data.csv')
df.head()
Out[86]:
Asset_ID Asset_Type Age_of_Asset Usage_Hours Temperature Pressure Fuel_Consumption Vibration_Levels Humidity Location Maintenance_History Failure_Event
0 1 Ship 9 5740 44.968711 137.776053 49.336583 0.043240 59.341921 Temperate 2 0
1 2 Aircraft 16 8326 74.398522 37.204956 95.947730 0.171963 27.762498 Tropical 2 0
2 3 Ship 11 2667 58.640931 146.077726 72.209295 0.285799 48.473243 Desert 1 0
3 4 Ship 14 8436 73.961007 141.785229 33.231257 0.540367 25.760009 Temperate 1 0
4 5 Aircraft 8 6835 45.759227 56.240621 38.631240 0.204255 60.852518 Tropical 2 0
In [87]:
# Amount of null values per column (output shows all zeros — no missing data)
df.isnull().sum()
Out[87]:
Asset_ID               0
Asset_Type             0
Age_of_Asset           0
Usage_Hours            0
Temperature            0
Pressure               0
Fuel_Consumption       0
Vibration_Levels       0
Humidity               0
Location               0
Maintenance_History    0
Failure_Event          0
dtype: int64
In [88]:
df.count()
Out[88]:
Asset_ID               50000
Asset_Type             50000
Age_of_Asset           50000
Usage_Hours            50000
Temperature            50000
Pressure               50000
Fuel_Consumption       50000
Vibration_Levels       50000
Humidity               50000
Location               50000
Maintenance_History    50000
Failure_Event          50000
dtype: int64
In [89]:
df.Asset_ID.value_counts().unique()
Out[89]:
array([1])

We can see that the dataset contains no null values and no duplicate Asset_IDs.

In [90]:
df['Failure_Event'].value_counts()
Out[90]:
Failure_Event
0    42436
1     7564
Name: count, dtype: int64
In [91]:
# Pie chart of the target distribution, highlighting the class imbalance
failure_counts = df['Failure_Event'].value_counts()
ax = failure_counts.plot(
    kind='pie',
    autopct='%1.1f%%',
    figsize=(4, 4),
    labels=['No Failure', 'Failure'],
    startangle=90,
    explode=(0, 0.1),
)
ax.set_title("Failure Event Distribution")
ax.set_ylabel("")  # remove the default y-axis label
plt.show()
No description has been provided for this image
In [92]:
df.value_counts('Asset_Type')
Out[92]:
Asset_Type
Aircraft    16709
Vehicle     16682
Ship        16609
Name: count, dtype: int64
In [93]:
df[["Asset_Type","Failure_Event"]].value_counts()
Out[93]:
Asset_Type  Failure_Event
Aircraft    0                14204
Vehicle     0                14127
Ship        0                14105
Vehicle     1                 2555
Aircraft    1                 2505
Ship        1                 2504
Name: count, dtype: int64

The dataset is alarmingly imbalanced, with roughly 2,500 failure events per asset type compared to roughly 14,100 no-failure events.

In [94]:
# Stacked bar chart: failure vs. no-failure counts per asset type
type_failure_counts = df[["Asset_Type", "Failure_Event"]].value_counts().unstack()
ax = type_failure_counts.plot(kind='bar', stacked=True, figsize=(6, 4))
ax.set_title("Failure Events by Asset Type")
ax.set_xlabel("Asset Type")
ax.set_ylabel("Count")
ax.legend(title="Failure Event", labels=["No Failure", "Failure"])
plt.show()
No description has been provided for this image
In [95]:
# Histograms (with KDE) for every numeric column, on a 3x3 grid.
# NOTE: numeric_columns is reused by later cells (box plots / outlier scan).
numeric_columns = df.drop(columns=["Asset_Type", "Location", "Failure_Event"]).columns

plt.figure(figsize=(16, 12))
for plot_pos, column in enumerate(numeric_columns, start=1):
    plt.subplot(3, 3, plot_pos)
    sns.histplot(df[column], kde=True)
    plt.title(f'Distribution of {column}')
plt.tight_layout()
plt.show()
No description has been provided for this image

Asset Information

  • Age_of_Asset: Age ranges mostly between 1 to 18 years. There's a relatively even spread, but a slight dip around ages 6–10 could suggest fewer assets in that age range or maintenance/replacement around that period.
  • Usage_Hours: Fairly uniform distribution up to ~10,000 hours. Indicates that usage varies widely, which is good for training models to learn from both low and high-usage cases.
  • Sensor and Operational Data: Temperature, Pressure, Fuel_Consumption, Vibration_Levels, Humidity
    These features show approximately uniform distributions, with minor fluctuations. This suggests good feature variability, which is beneficial for training — the model has exposure to a wide range of conditions. No clear skewness or major outliers are visible, indicating data is already well-cleaned.
In [96]:
# Ignore categorical columns and the target so only numeric features are correlated
df_corr = df.drop(columns=["Asset_Type", "Location", "Failure_Event"])
# Compute the correlation matrix
corr_matrix = df_corr.corr()

# Set up the plot
plt.figure(figsize=(12, 10))
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=".2f", square=True, cbar_kws={"shrink": .8})
plt.title("Correlation Matrix of Features", fontsize=16)
plt.tight_layout()
plt.show()
No description has been provided for this image

No multicollinearity among the variables.

In [97]:
# Box plots of each numeric feature split by failure / no-failure class
plt.figure(figsize=(16, 12))
for plot_pos, column in enumerate(numeric_columns, start=1):
    plt.subplot(3, 3, plot_pos)
    sns.boxplot(x='Failure_Event', y=column, data=df)
    plt.title(f'Box Plot of {column} by Failure Event')
plt.tight_layout()
plt.show()

# Z-score outlier scan: flag values more than 3 standard deviations from the mean
outlier_threshold = 3
outlier_columns = numeric_columns
outliers = {}
for column in outlier_columns:
    standardized = (df[column] - df[column].mean()) / df[column].std()
    outliers[column] = df[np.abs(standardized) > outlier_threshold][column]
    print(f"Outliers in {column}:")
    print(outliers[column])
No description has been provided for this image
Outliers in Asset_ID:
Series([], Name: Asset_ID, dtype: int64)
Outliers in Age_of_Asset:
Series([], Name: Age_of_Asset, dtype: int64)
Outliers in Usage_Hours:
Series([], Name: Usage_Hours, dtype: int64)
Outliers in Temperature:
Series([], Name: Temperature, dtype: float64)
Outliers in Pressure:
Series([], Name: Pressure, dtype: float64)
Outliers in Fuel_Consumption:
Series([], Name: Fuel_Consumption, dtype: float64)
Outliers in Vibration_Levels:
Series([], Name: Vibration_Levels, dtype: float64)
Outliers in Humidity:
Series([], Name: Humidity, dtype: float64)
Outliers in Maintenance_History:
Series([], Name: Maintenance_History, dtype: int64)

Most features show some overlap between failure and non-failure events, but a few features exhibit distinct shifts or spread differences, indicating their predictive potential.

  • Older assets and those with higher usage hours tend to experience more failures.

  • Higher vibration levels and elevated temperatures are noticeably associated with failure events, indicating they are strong predictors of potential issues.

  • Fuel consumption and maintenance history show slight increases in failed assets, suggesting possible early warning signs.

  • Pressure and humidity show minimal differences, indicating limited impact individually.


In [ ]:
# Stacked bar chart: failure vs. no-failure counts per deployment location.
# Fix: the title/xlabel previously said "Asset Type" (copy-paste error) even
# though this plot groups by Location. plt.show() added to suppress the bare
# Legend repr in the cell output.
location_failure_counts = df[["Location", "Failure_Event"]].value_counts().unstack()
location_failure_counts.plot(kind='bar', stacked=True, figsize=(6, 4))
plt.title("Failure Events by Location")
plt.xlabel("Location")
plt.ylabel("Count")
plt.legend(title="Failure Event", labels=["No Failure", "Failure"])
plt.show()
<matplotlib.legend.Legend at 0x33aa0bdc0>
No description has been provided for this image

Data preparation¶

Mapping categorical variables¶

In [98]:
# Integer-encode each categorical column in order of first appearance,
# keeping every value->code mapping for later reference / inverse lookup.
categorical_cols = ['Asset_Type', 'Location']

# Dictionary that stores the value->code mapping for each encoded column
category_mappings = {}

for col in categorical_cols:
    mapping = {value: code for code, value in enumerate(df[col].unique())}
    category_mappings[col] = mapping
    df[col] = df[col].map(mapping)

for col, mapping in category_mappings.items():
    print(f"{col} mapping: {mapping}")
Asset_Type mapping: {'Ship': 0, 'Aircraft': 1, 'Vehicle': 2}
Location mapping: {'Temperate': 0, 'Tropical': 1, 'Desert': 2}

Feature Engineering to capture complex relationships in data¶

In [99]:
# Engineered interaction features to capture compound stress effects
df['Thermal_Stress'] = df['Usage_Hours'] * df['Temperature']
df['Age_Vibration_Interaction'] = df['Age_of_Asset'] * df['Vibration_Levels']
df['Fuel_Efficiency'] = df['Fuel_Consumption'] / (df['Usage_Hours'] + 1e-5)  # avoid division by 0
df['Pressure_Temp_Interaction'] = df['Pressure'] * df['Temperature']
# Composite operational load normalized by asset age (epsilon avoids division by zero)
df['Operational_Stress_Index'] = (df['Vibration_Levels'] + df['Pressure'] +  df['Temperature'] +  df['Usage_Hours'])/ (df['Age_of_Asset'] + 1e-5)
# Asset_ID is a unique row identifier (verified earlier) with no predictive value
df.drop(columns=['Asset_ID'], inplace=True)
df.head()
Out[99]:
Asset_Type Age_of_Asset Usage_Hours Temperature Pressure Fuel_Consumption Vibration_Levels Humidity Location Maintenance_History Failure_Event Thermal_Stress Age_Vibration_Interaction Fuel_Efficiency Pressure_Temp_Interaction Operational_Stress_Index
0 0 9 5740 44.968711 137.776053 49.336583 0.043240 59.341921 0 2 0 258120.403835 0.389162 0.008595 6195.611574 658.086825
1 1 16 8326 74.398522 37.204956 95.947730 0.171963 27.762498 1 2 0 619442.090624 2.751404 0.011524 2767.993702 527.360635
2 0 11 2667 58.640931 146.077726 72.209295 0.285799 48.473243 2 1 0 156395.363951 3.143794 0.027075 8566.133914 261.091077
3 0 14 8436 73.961007 141.785229 33.231257 0.540367 25.760009 0 1 0 623935.057788 7.565140 0.003939 10486.578333 618.020030
4 1 8 6835 45.759227 56.240621 38.631240 0.204255 60.852518 1 2 0 312764.319305 1.634040 0.005652 2573.527343 867.149429

Feature Engineering Rationale¶

These features are designed to:

  • Capture interactions between variables that may not be obvious from individual inputs.
  • Reflect real-world mechanical stress patterns.
  • Improve model predictive performance by injecting domain knowledge into the data.
  1. Thermal_Stress = Usage_Hours * Temperature
  • Captures the cumulative thermal load an asset experiences.
  • Reflects extended usage in high-temperature conditions, which can accelerate wear and tear.
  1. Age_Vibration_Interaction = Age_of_Asset * Vibration_Levels
  • Combines mechanical wear (vibration) with asset aging.
  • Older assets with high vibration levels are more prone to failure, making this an important risk indicator.
  1. Fuel_Efficiency = Fuel_Consumption / (Usage_Hours + 1e-5)
  • Represents how efficiently an asset consumes fuel relative to its operational time.
  • Declining efficiency may signal engine or system degradation, indicating the need for maintenance.
  1. Pressure_Temp_Interaction = Pressure * Temperature
  • Models the combined stress of internal pressure and temperature.
  • High values could indicate overheating or internal system strain, which may lead to failure.
  1. Operational_Stress_Index = (Vibration_Levels + Pressure + Temperature + Usage_Hours) / (Age_of_Asset + 1e-5)
  • A composite score that reflects the overall operational stress experienced by an asset.
  • Normalized by asset age to provide a measure of how much load the asset is handling relative to its lifecycle.
  • Higher values may suggest overuse or abnormal stress, potentially increasing failure risk.

These engineered features introduce meaningful interactions that are likely to improve model performance, particularly in scenarios involving military asset maintenance, where wear-and-tear dynamics are complex and interdependent.


In [100]:
# Correlation heatmap including the engineered features and encoded categoricals
corr_matrix = df.corr()

# Explicit figure/axes interface for the heatmap
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=".2f", square=True, cbar_kws={"shrink": .8}, ax=ax)
ax.set_title("Correlation Matrix of Features", fontsize=16)
fig.tight_layout()
plt.show()
No description has been provided for this image

Engineered features are mathematically and contextually justified, and their correlations confirm their relevance.

Since no feature has a strong linear relationship with Failure_Event, feature interactions and ensemble models are critical.

Some features (like Humidity, Location, and Asset_Type) have near-zero correlation with most variables, which may warrant further analysis or dimensionality reduction.


Data Pre Processing¶

In [101]:
# Features and target
X = df.drop(columns=['Failure_Event'])
y = df['Failure_Event']

# Define categorical and numerical features
cat_features = ['Asset_Type', 'Location', 'Maintenance_History']
num_features = [col for col in X.columns if col not in cat_features]

# Train-test split (stratified so both splits keep the ~15% failure rate)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# --- Apply SMOTE to balance the classes in the training data ---
# SMOTENC needs the positional indices of the categorical columns
categorical_indices = [X_train.columns.get_loc(col) for col in cat_features]

print("Before SMOTE:", Counter(y_train))

# Apply SMOTENC (for categorical and continuous features).
# Only the training split is resampled; the validation split stays untouched.
smote_nc = SMOTENC(
    categorical_features=categorical_indices,
    sampling_strategy=0.5,  # Make minority class 50% the size of the majority class
    random_state=42
)

X_train_balanced, y_train_balanced = smote_nc.fit_resample(X_train, y_train)

# Convert back to DataFrame for further processing
X_train_final = pd.DataFrame(X_train_balanced, columns=X_train.columns)
y_train_final = pd.Series(y_train_balanced, name='Failure_Event')

print("After SMOTE:", Counter(y_train_balanced))

# --- Scale numeric features after SMOTE ---
# The scaler is fit on the (resampled) training data only, so no statistics
# leak from the validation split.
scaler = StandardScaler()

# Separate numerical columns for scaling
X_train_num = X_train_final[num_features]
X_val_num = X_val[num_features]

X_train_num_scaled = scaler.fit_transform(X_train_num)
X_val_num_scaled = scaler.transform(X_val_num)

# Convert back to DataFrame (reset_index keeps rows positionally aligned for the concat below)
X_train_num_scaled_df = pd.DataFrame(X_train_num_scaled, columns=num_features).reset_index(drop=True)
X_val_num_scaled_df = pd.DataFrame(X_val_num_scaled, columns=num_features).reset_index(drop=True)

# Prepare final training and validation sets by combining scaled numeric and encoded categorical features.
# NOTE(review): X_train_final is reassigned here (it previously held the unscaled
# resampled frame) and the column order changes to numeric-first; later cells use
# the unscaled X_train_balanced for the tree-based models instead.
X_train_final = pd.concat([X_train_num_scaled_df, X_train_final[cat_features].reset_index(drop=True)], axis=1)
X_val_final = pd.concat([X_val_num_scaled_df, X_val[cat_features].reset_index(drop=True)], axis=1)
Before SMOTE: Counter({0: 33949, 1: 6051})
After SMOTE: Counter({0: 33949, 1: 16974})

  1. Feature Selection

    • Target variable Failure_Event was separated from the features.
    • Categorical features: Asset_Type, Location, Maintenance_History.
    • Remaining columns treated as numerical features.
  2. Train-Test Split

    • Data was split 80/20 using stratified sampling to maintain class distribution.
  3. Class Imbalance Handling

    • Applied SMOTENC to oversample the minority class (Failure_Event = 1) while handling both categorical and numerical data.
    • Sampling strategy increased the minority class to 50% of the majority. This will allow me to use class weights.
  4. Feature Scaling

    • Numerical features were scaled using StandardScaler to normalize distributions.
    • Categorical features were preserved without scaling.
  5. Final Dataset Preparation

    • Scaled numerical and raw categorical features were combined to form X_train_final and X_val_final.

Even after applying SMOTENC, class imbalance may still persist slightly.

Using class weights in the model ensures that:

  • The algorithm pays more attention to the minority class (failures), which is critical in high-risk domains like military maintenance.
  • It compensates for any remaining imbalance, improving the model's sensitivity to rare but important failure events.
  • It works synergistically with SMOTENC — SMOTENC balances the data distribution, while class weights adjust the loss function during training.

NOTE: I experimented with using only class weights without any SMOTENC, but the model didn't predict the minority class. I tried using SMOTENC without class weights, and the model behaved the same. Thus I have used SMOTENC with the minority class oversampled to constitute only 50% of the majority class, and used class weights where required.


Random Forest¶

We will first use grid search to find the best parameters for the Random Forest model. The model will then be trained using the k-fold cross validation.

Grid Search for best parameters¶

In [102]:
# Set up a Random Forest with basic pruning options
rf = RandomForestClassifier(random_state=42, n_jobs=-1, class_weight='balanced')

# Grid Search for best hyperparameters including pruning-related ones
param_grid = {
    'n_estimators': [100],
    'max_depth': [3, 5, 7, None],  # Control tree size (pruning)
    'min_samples_split': [2, 5, 10],  # Prevent overgrowth
    'min_samples_leaf': [1, 2, 4],
    'max_features': ['sqrt', 'log2']
}

# NOTE(review): the CV here runs on the SMOTE-resampled data, so synthetic
# minority samples appear in both CV-train and CV-validation folds; CV scores
# are therefore likely optimistic compared with the held-out validation split
# (the printed validation AUC-ROC below is near 0.50).
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=3, scoring='roc_auc', verbose=1)
grid_search.fit(X_train_balanced, y_train_balanced) # using unscaled data since Random Forests are scale-invariant

# Get best model and evaluate on the untouched validation split
best_rf_F = grid_search.best_estimator_
y_pred = best_rf_F.predict(X_val)
y_prob = best_rf_F.predict_proba(X_val)[:, 1]

print("Best Params:", grid_search.best_params_)
print("AUC-ROC:", roc_auc_score(y_val, y_prob))
print(classification_report(y_val, y_pred, digits=4))
mcc = matthews_corrcoef(y_val, y_pred)
print(f"MCC: {mcc:.4f}")

# Plot one of the trees in the forest for interpretability
plt.figure(figsize=(30, 20))
plot_tree(best_rf_F.estimators_[0], 
          feature_names=X.columns, 
          class_names=['Class 0', 'Class 1'], 
          filled=True, 
          rounded=True,
          max_depth=3,
          fontsize=10)  # Only show top 3 levels for clarity
plt.title("Random Forest - Tree Visualization")
plt.show()

# Persist the tuned model for later inference
joblib.dump(best_rf_F, 'Failure_event_random_forest_model.joblib')
Fitting 3 folds for each of 72 candidates, totalling 216 fits
Best Params: {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 100}
AUC-ROC: 0.5029605171191802
              precision    recall  f1-score   support

           0     0.8495    0.9390    0.8920      8487
           1     0.1632    0.0668    0.0947      1513

    accuracy                         0.8070     10000
   macro avg     0.5063    0.5029    0.4934     10000
weighted avg     0.7456    0.8070    0.7714     10000

MCC: 0.0085
No description has been provided for this image
Out[102]:
['Failure_event_random_forest_model.joblib']

Random Forest Classifier with Pruning and Hyperparameter Tuning¶

Model Setup¶

A RandomForestClassifier is initialized with:

  • class_weight='balanced' to handle class imbalance by penalizing the majority class during training.
  • random_state=42 ensures reproducibility.
  • n_jobs=-1 enables parallel computation for faster training.
Hyperparameter Tuning using GridSearchCV¶

A grid search is performed over key pruning-related hyperparameters:

  • n_estimators: Number of trees in the forest (fixed at 100).
  • max_depth: Maximum depth of each tree to prevent overfitting.
  • min_samples_split: Minimum samples required to split a node (controls granularity).
  • min_samples_leaf: Minimum samples at a leaf node (avoids overly specific rules).
  • max_features: Number of features to consider when splitting a node (sqrt, log2 help reduce variance).

Scoring Metric: roc_auc
Cross-validation: 3-fold (cv=3)

Note: X_train_balanced is used without scaling, as Random Forests are scale-invariant.

Model Evaluation¶

The best model from GridSearchCV is used to predict on the validation set:

  • AUC-ROC is computed to measure overall classification performance.
  • A full classification_report is printed (precision, recall, F1-score).
  • Matthews Correlation Coefficient (MCC) is calculated to provide a balanced metric even with class imbalance.
Visualizing One Decision Tree¶

A single tree (first in the ensemble) is visualized with:

  • max_depth=3 for clarity.
  • Feature names and class labels included.
  • Colors and shapes enhance interpretability.

Random Forest is an ensemble, but visualizing one tree helps understand individual decision paths.

Model Saving¶

The trained best model is saved using joblib for future inference:

joblib.dump(best_rf_F, 'Failure_event_random_forest_model.joblib')

Random Forest Tree Interpretation from Random Forest (Top 3 Levels)

Key Features Used for Splitting

  1. Vibration_Levels (Root Node)

    • Most important feature at the root.
    • Lower vibration values are associated with higher likelihood of failure (Class 1).
  2. Age_Vibration_Interaction

    • Combines asset age and vibration to capture compound degradation.
    • Lower interaction values (younger assets with some vibration) lean toward Class 0.
  3. Pressure and Operational_Stress_Index

    • Further refine splits based on operational intensity.
    • High stress or pressure correlates with different failure risks.
  4. Fuel_Consumption

    • Used repeatedly at different depths.
    • Low consumption is often associated with failed assets.
  5. Thermal_Stress and Pressure_Temp_Interaction

    • Capture specific mechanical strain or operational load effects.
    • Moderate impact on classification in sub-branches.

General Observations

  • The model is using engineered features effectively (e.g., Age_Vibration_Interaction, Operational_Stress_Index, Thermal_Stress), showing their importance.
  • Failure prediction (Class 1) is influenced by combinations of operational intensity and asset condition.
  • Decision splits generally reflect intuitive domain logic: high stress, abnormal vibration, or low fuel efficiency increases failure risk.

Decision Tree Depth and Clarity

  • Only the top 3 levels are visualized to maintain clarity.
  • Beyond this, additional splits continue refining classifications using similar or supporting features.

This tree is a single estimator from the Random Forest — useful for interpretability, but the final prediction is made by aggregating across all trees.


K-fold on Best Rf model¶

In [104]:
# Re-import metric helpers so this cell also works when run standalone
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import (
    roc_auc_score,
    classification_report,
    matthews_corrcoef,
    confusion_matrix,
    ConfusionMatrixDisplay,
    precision_recall_curve,
    auc  # re-importing packages so I can use it without any error
)

# Define Stratified K-Fold
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Store metrics for summary
roc_auc_scores = []

# Loop over each fold.
# NOTE(review): folds are drawn from the SMOTE-resampled training data, so
# synthetic minority samples appear in both the train and validation folds;
# the reported metrics may be optimistic relative to real-world data.
for fold, (train_idx, val_idx) in enumerate(cv.split(X_train_balanced, y_train_balanced), start=1):
    X_train_fold, X_val_fold = X_train_balanced.iloc[train_idx], X_train_balanced.iloc[val_idx]
    y_train_fold, y_val_fold = y_train_balanced.iloc[train_idx], y_train_balanced.iloc[val_idx]
    
    # Train the best model on this fold.
    # NOTE(review): this refits best_rf_F in place — after the loop the variable
    # holds the model trained on the LAST fold, not the grid-search fit that was
    # saved to disk in the previous cell.
    best_rf_F.fit(X_train_fold, y_train_fold)
    y_pred = best_rf_F.predict(X_val_fold)
    y_proba = best_rf_F.predict_proba(X_val_fold)[:, 1]
    
    # AUC for this fold
    roc_auc = roc_auc_score(y_val_fold, y_proba)
    roc_auc_scores.append(roc_auc)
    
    # Print classification report
    print(f"\nFold {fold} - Classification Report:")
    print(classification_report(y_val_fold, y_pred, digits=4))
    print(f"Fold {fold} - AUC-ROC: {roc_auc:.4f}")
    mcc = matthews_corrcoef(y_val_fold, y_pred)
    print(f"MCC: {mcc:.4f}")
    # Confusion Matrix
    cm = confusion_matrix(y_val_fold, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap='Blues')
    plt.title("Confusion Matrix")
    plt.show()

    # Compute ROC curve and AUC
    fpr, tpr, thresholds = roc_curve(y_val_fold, y_proba)
    roc_auc_curve = auc(fpr, tpr)
    
    # Histogram of predicted probabilities, split by true class
    plt.figure(figsize=(6, 4))
    plt.hist(y_proba[y_val_fold == 0], bins=30, alpha=0.6, label='Class 0 (No Failure)', color='skyblue')
    plt.hist(y_proba[y_val_fold == 1], bins=30, alpha=0.6, label='Class 1 (Failure)', color='salmon')
    plt.xlabel('Predicted Probability')
    plt.ylabel('Count')
    plt.title('Histogram of Predicted Probabilities by Class')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

    # Plot ROC Curve
    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC AUC = {roc_auc:.4f}")
    plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')  # Diagonal line
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.show()

# Print average AUC over all folds
print("\nAverage AUC-ROC across folds:")
print(f"{np.mean(roc_auc_scores):.4f} ± {np.std(roc_auc_scores):.4f}")
Fold 1 - Classification Report:
              precision    recall  f1-score   support

           0     0.7487    0.9405    0.8337      6790
           1     0.7560    0.3688    0.4957      3395

    accuracy                         0.7499     10185
   macro avg     0.7524    0.6546    0.6647     10185
weighted avg     0.7512    0.7499    0.7211     10185

Fold 1 - AUC-ROC: 0.7769
MCC: 0.3951
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Fold 2 - Classification Report:
              precision    recall  f1-score   support

           0     0.7461    0.9471    0.8347      6790
           1     0.7706    0.3552    0.4863      3395

    accuracy                         0.7498     10185
   macro avg     0.7583    0.6512    0.6605     10185
weighted avg     0.7542    0.7498    0.7185     10185

Fold 2 - AUC-ROC: 0.7795
MCC: 0.3952
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Fold 3 - Classification Report:
              precision    recall  f1-score   support

           0     0.7450    0.9367    0.8299      6790
           1     0.7391    0.3588    0.4830      3395

    accuracy                         0.7440     10185
   macro avg     0.7420    0.6477    0.6565     10185
weighted avg     0.7430    0.7440    0.7143     10185

Fold 3 - AUC-ROC: 0.7722
MCC: 0.3782
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Fold 4 - Classification Report:
              precision    recall  f1-score   support

           0     0.7445    0.9412    0.8314      6790
           1     0.7506    0.3539    0.4810      3394

    accuracy                         0.7455     10184
   macro avg     0.7476    0.6475    0.6562     10184
weighted avg     0.7466    0.7455    0.7146     10184

Fold 4 - AUC-ROC: 0.7706
MCC: 0.3823
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Fold 5 - Classification Report:
              precision    recall  f1-score   support

           0     0.7460    0.9378    0.8310      6789
           1     0.7441    0.3614    0.4865      3395

    accuracy                         0.7457     10184
   macro avg     0.7450    0.6496    0.6588     10184
weighted avg     0.7454    0.7457    0.7162     10184

Fold 5 - AUC-ROC: 0.7721
MCC: 0.3830
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Average AUC-ROC across folds:
0.7742 ± 0.0034

TAB NET¶

In [105]:
# Configuration
TARGET_COL = 'Failure_Event'
CATEGORICAL_COLS = ['Asset_Type', 'Location', 'Maintenance_History']
RANDOM_STATE = 42
N_SPLITS = 5
EPOCHS = 200
BATCH_SIZE = 64

# Identify categorical feature indices and dimensions
cat_idxs = [X_train_balanced.columns.get_loc(col) for col in CATEGORICAL_COLS]
cat_dims = [int(df[col].nunique()) for col in CATEGORICAL_COLS]
cat_emb_dim = [min(50, (dim + 1) // 2) for dim in cat_dims]

# Cross-validation setup
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)

fold = 1
auc_scores = []

classes = np.unique(y_train_balanced)
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train_balanced)
weights = dict(zip(classes, class_weights))

for train_idx, val_idx in skf.split(X_train_balanced, y_train_balanced):
    print(f"\n==== Fold {fold} ====")
    
    # ---- Fold split: index into the SMOTE-balanced training frame ----
    TAB_X_train, TAB_X_val = X_train_balanced.iloc[train_idx], X_train_balanced.iloc[val_idx]
    TAB_y_train, TAB_y_val = y_train_balanced.iloc[train_idx], y_train_balanced.iloc[val_idx]

    # Initialize a fresh TabNet model for each fold so folds do not share weights.
    clf_F = TabNetClassifier(
        n_d=32, n_a=32, n_steps=5, gamma=1.5,
        cat_idxs=cat_idxs,                 # positions of categorical columns
        cat_dims=cat_dims,                 # cardinality of each categorical column
        cat_emb_dim=cat_emb_dim,           # embedding size per categorical column
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=1e-2),
        scheduler_params={"step_size": 10, "gamma": 0.9},  # decay LR by 0.9 every 10 epochs
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        mask_type='entmax',                # sparse attention masks
        seed=RANDOM_STATE,
        verbose=1
    )

    clf_F.fit(
        X_train=TAB_X_train.values, y_train=TAB_y_train.values,
        eval_set=[(TAB_X_val.values, TAB_y_val.values)],
        eval_name=['val'],
        eval_metric=['auc'],               # early stopping monitors validation AUC
        max_epochs=EPOCHS,
        patience=20,                       # stop after 20 epochs without val-AUC improvement
        batch_size=BATCH_SIZE,
        virtual_batch_size=128,            # ghost batch-norm size
        weights=weights                    # class weighting for the imbalanced target
    )

    # ---- Fold evaluation ----
    y_pred_proba = clf_F.predict_proba(TAB_X_val.values)[:, 1]
    y_pred = clf_F.predict(TAB_X_val.values)

    # Use a distinct name for the scalar score so we do not shadow sklearn's
    # `auc` function (the original code rebound `auc` to a float, which forced
    # a re-import of sklearn.metrics.auc inside the loop on every fold).
    fold_auc = roc_auc_score(TAB_y_val.values, y_pred_proba)
    print(classification_report(TAB_y_val.values, y_pred, digits=4))
    print(f"Fold {fold} AUC: {fold_auc:.4f}")
    auc_scores.append(fold_auc)

    mcc = matthews_corrcoef(TAB_y_val, y_pred)
    print(f"MCC: {mcc:.4f}")

    # Confusion matrix for this fold's validation split
    cm = confusion_matrix(TAB_y_val, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap='Blues')
    plt.title("Confusion Matrix")
    plt.show()

    # ROC curve. `roc_curve` is imported at the top of the notebook; the AUC
    # shown in the legend reuses `fold_auc` (for binary probability scores
    # roc_auc_score is identical to auc(fpr, tpr)), so there is no need to
    # recompute it or re-import `auc` here.
    fpr, tpr, thresholds = roc_curve(TAB_y_val, y_pred_proba)

    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {fold_auc:.4f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.grid()
    plt.tight_layout()
    plt.show()

    # Distribution of predicted probabilities per true class — a visual check
    # of how well the two classes are separated by the model's scores.
    plt.figure(figsize=(6, 4))
    plt.hist(y_pred_proba[TAB_y_val == 0], bins=30, alpha=0.6, label='Class 0 (No Failure)', color='skyblue')
    plt.hist(y_pred_proba[TAB_y_val == 1], bins=30, alpha=0.6, label='Class 1 (Failure)', color='salmon')
    plt.xlabel('Predicted Probability')
    plt.ylabel('Count')
    plt.title('Histogram of Predicted Probabilities by Class')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

    fold += 1

# ---- Cross-validation summary ----
# Mean/std of the per-fold validation AUCs collected in `auc_scores`.
print("\n==== Cross-Validation Complete ====")  # plain string; no f-prefix needed
print(f"Mean AUC: {np.mean(auc_scores):.4f} | Std AUC: {np.std(auc_scores):.4f}")

# Persist the model to disk (pytorch-tabnet appends ".zip" to the path).
# NOTE(review): only the LAST fold's model is saved here, not the best fold —
# confirm that is intended.
clf_F.save_model("Failure_Event_tabnet_model")
==== Fold 1 ====
epoch 0  | loss: 0.71458 | val_auc: 0.58514 |  0:00:08s
epoch 1  | loss: 0.68575 | val_auc: 0.59763 |  0:00:15s
epoch 2  | loss: 0.68108 | val_auc: 0.61216 |  0:00:23s
epoch 3  | loss: 0.67893 | val_auc: 0.61338 |  0:00:31s
epoch 4  | loss: 0.67825 | val_auc: 0.61639 |  0:00:39s
epoch 5  | loss: 0.6744  | val_auc: 0.61304 |  0:00:47s
epoch 6  | loss: 0.67283 | val_auc: 0.6182  |  0:00:55s
epoch 7  | loss: 0.6697  | val_auc: 0.6287  |  0:01:04s
epoch 8  | loss: 0.67066 | val_auc: 0.63418 |  0:01:12s
epoch 9  | loss: 0.66839 | val_auc: 0.63039 |  0:01:21s
epoch 10 | loss: 0.67151 | val_auc: 0.62486 |  0:01:29s
epoch 11 | loss: 0.67072 | val_auc: 0.61996 |  0:01:38s
epoch 12 | loss: 0.67    | val_auc: 0.63585 |  0:01:49s
epoch 13 | loss: 0.66961 | val_auc: 0.62514 |  0:01:58s
epoch 14 | loss: 0.66823 | val_auc: 0.63112 |  0:02:07s
epoch 15 | loss: 0.66733 | val_auc: 0.63498 |  0:02:15s
epoch 16 | loss: 0.66758 | val_auc: 0.6419  |  0:02:24s
epoch 17 | loss: 0.66449 | val_auc: 0.64695 |  0:02:34s
epoch 18 | loss: 0.66605 | val_auc: 0.64354 |  0:02:42s
epoch 19 | loss: 0.6628  | val_auc: 0.64159 |  0:02:51s
epoch 20 | loss: 0.66557 | val_auc: 0.64505 |  0:03:00s
epoch 21 | loss: 0.66332 | val_auc: 0.64416 |  0:03:09s
epoch 22 | loss: 0.66111 | val_auc: 0.64345 |  0:03:19s
epoch 23 | loss: 0.66118 | val_auc: 0.64386 |  0:03:29s
epoch 24 | loss: 0.66016 | val_auc: 0.65331 |  0:03:39s
epoch 25 | loss: 0.65964 | val_auc: 0.65088 |  0:03:47s
epoch 26 | loss: 0.65719 | val_auc: 0.65574 |  0:03:56s
epoch 27 | loss: 0.6491  | val_auc: 0.671   |  0:04:04s
epoch 28 | loss: 0.63201 | val_auc: 0.66614 |  0:04:13s
epoch 29 | loss: 0.61198 | val_auc: 0.69575 |  0:04:22s
epoch 30 | loss: 0.58723 | val_auc: 0.70743 |  0:04:30s
epoch 31 | loss: 0.56325 | val_auc: 0.74714 |  0:04:40s
epoch 32 | loss: 0.54804 | val_auc: 0.7355  |  0:04:48s
epoch 33 | loss: 0.54409 | val_auc: 0.755   |  0:04:57s
epoch 34 | loss: 0.53694 | val_auc: 0.71618 |  0:05:06s
epoch 35 | loss: 0.52921 | val_auc: 0.73014 |  0:05:16s
epoch 36 | loss: 0.5318  | val_auc: 0.70244 |  0:05:25s
epoch 37 | loss: 0.52623 | val_auc: 0.71791 |  0:05:34s
epoch 38 | loss: 0.52309 | val_auc: 0.74466 |  0:05:45s
epoch 39 | loss: 0.52391 | val_auc: 0.73747 |  0:05:54s
epoch 40 | loss: 0.51575 | val_auc: 0.71515 |  0:06:04s
epoch 41 | loss: 0.51472 | val_auc: 0.75724 |  0:06:13s
epoch 42 | loss: 0.51097 | val_auc: 0.76026 |  0:06:23s
epoch 43 | loss: 0.51218 | val_auc: 0.71206 |  0:06:32s
epoch 44 | loss: 0.51132 | val_auc: 0.63665 |  0:06:41s
epoch 45 | loss: 0.51298 | val_auc: 0.75833 |  0:06:50s
epoch 46 | loss: 0.50666 | val_auc: 0.75484 |  0:07:00s
epoch 47 | loss: 0.50242 | val_auc: 0.74443 |  0:07:10s
epoch 48 | loss: 0.50202 | val_auc: 0.67663 |  0:07:20s
epoch 49 | loss: 0.50026 | val_auc: 0.72603 |  0:07:29s
epoch 50 | loss: 0.49844 | val_auc: 0.70896 |  0:07:39s
epoch 51 | loss: 0.49932 | val_auc: 0.69035 |  0:07:48s
epoch 52 | loss: 0.49436 | val_auc: 0.75091 |  0:07:57s
epoch 53 | loss: 0.4942  | val_auc: 0.73055 |  0:08:05s
epoch 54 | loss: 0.49471 | val_auc: 0.76354 |  0:08:14s
epoch 55 | loss: 0.49058 | val_auc: 0.66817 |  0:08:22s
epoch 56 | loss: 0.49392 | val_auc: 0.72228 |  0:08:31s
epoch 57 | loss: 0.49568 | val_auc: 0.74181 |  0:08:39s
epoch 58 | loss: 0.4904  | val_auc: 0.6678  |  0:08:48s
epoch 59 | loss: 0.49121 | val_auc: 0.73797 |  0:08:56s
epoch 60 | loss: 0.48977 | val_auc: 0.68416 |  0:09:05s
epoch 61 | loss: 0.48328 | val_auc: 0.76191 |  0:09:14s
epoch 62 | loss: 0.48807 | val_auc: 0.70157 |  0:09:22s
epoch 63 | loss: 0.48529 | val_auc: 0.69813 |  0:09:31s
epoch 64 | loss: 0.48864 | val_auc: 0.77281 |  0:09:40s
epoch 65 | loss: 0.4869  | val_auc: 0.73468 |  0:09:48s
epoch 66 | loss: 0.4867  | val_auc: 0.66701 |  0:09:56s
epoch 67 | loss: 0.48994 | val_auc: 0.73992 |  0:10:08s
epoch 68 | loss: 0.48501 | val_auc: 0.74521 |  0:10:28s
epoch 69 | loss: 0.48666 | val_auc: 0.74401 |  0:10:41s
epoch 70 | loss: 0.48619 | val_auc: 0.72978 |  0:10:53s
epoch 71 | loss: 0.48218 | val_auc: 0.74547 |  0:11:06s
epoch 72 | loss: 0.48358 | val_auc: 0.76971 |  0:11:17s
epoch 73 | loss: 0.47678 | val_auc: 0.75136 |  0:11:28s
epoch 74 | loss: 0.48124 | val_auc: 0.7399  |  0:11:39s
epoch 75 | loss: 0.48078 | val_auc: 0.77214 |  0:11:49s
epoch 76 | loss: 0.47987 | val_auc: 0.70952 |  0:11:58s
epoch 77 | loss: 0.4761  | val_auc: 0.73663 |  0:12:07s
epoch 78 | loss: 0.4817  | val_auc: 0.74254 |  0:12:16s
epoch 79 | loss: 0.47893 | val_auc: 0.74586 |  0:12:25s
epoch 80 | loss: 0.47633 | val_auc: 0.74668 |  0:12:35s
epoch 81 | loss: 0.47451 | val_auc: 0.75681 |  0:12:43s
epoch 82 | loss: 0.47598 | val_auc: 0.72161 |  0:12:52s
epoch 83 | loss: 0.47235 | val_auc: 0.73272 |  0:13:00s
epoch 84 | loss: 0.47724 | val_auc: 0.74193 |  0:13:09s

Early stopping occurred at epoch 84 with best_epoch = 64 and best_val_auc = 0.77281
              precision    recall  f1-score   support

           0     0.7991    0.9710    0.8767      6790
           1     0.8981    0.5116    0.6519      3395

    accuracy                         0.8179     10185
   macro avg     0.8486    0.7413    0.7643     10185
weighted avg     0.8321    0.8179    0.8017     10185

Fold 1 AUC: 0.7728
MCC: 0.5801
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
==== Fold 2 ====
epoch 0  | loss: 0.7257  | val_auc: 0.53746 |  0:00:08s
epoch 1  | loss: 0.68761 | val_auc: 0.58829 |  0:00:16s
epoch 2  | loss: 0.68516 | val_auc: 0.59725 |  0:00:24s
epoch 3  | loss: 0.68323 | val_auc: 0.59705 |  0:00:32s
epoch 4  | loss: 0.67999 | val_auc: 0.60292 |  0:00:40s
epoch 5  | loss: 0.68029 | val_auc: 0.61052 |  0:00:48s
epoch 6  | loss: 0.67869 | val_auc: 0.6098  |  0:00:56s
epoch 7  | loss: 0.67758 | val_auc: 0.60534 |  0:01:03s
epoch 8  | loss: 0.6794  | val_auc: 0.60925 |  0:01:11s
epoch 9  | loss: 0.67816 | val_auc: 0.60583 |  0:01:19s
epoch 10 | loss: 0.67685 | val_auc: 0.60395 |  0:01:27s
epoch 11 | loss: 0.67939 | val_auc: 0.60684 |  0:01:34s
epoch 12 | loss: 0.68099 | val_auc: 0.59842 |  0:01:42s
epoch 13 | loss: 0.67846 | val_auc: 0.60648 |  0:01:49s
epoch 14 | loss: 0.67964 | val_auc: 0.60525 |  0:01:57s
epoch 15 | loss: 0.67872 | val_auc: 0.61034 |  0:02:04s
epoch 16 | loss: 0.68044 | val_auc: 0.60728 |  0:02:12s
epoch 17 | loss: 0.67801 | val_auc: 0.60308 |  0:02:20s
epoch 18 | loss: 0.67865 | val_auc: 0.60383 |  0:02:27s
epoch 19 | loss: 0.67857 | val_auc: 0.60348 |  0:02:35s
epoch 20 | loss: 0.67905 | val_auc: 0.60752 |  0:02:42s
epoch 21 | loss: 0.67902 | val_auc: 0.60377 |  0:02:50s
epoch 22 | loss: 0.67894 | val_auc: 0.60094 |  0:02:57s
epoch 23 | loss: 0.67621 | val_auc: 0.60858 |  0:03:05s
epoch 24 | loss: 0.67694 | val_auc: 0.61065 |  0:03:12s
epoch 25 | loss: 0.67631 | val_auc: 0.61026 |  0:03:20s
epoch 26 | loss: 0.67678 | val_auc: 0.61128 |  0:03:28s
epoch 27 | loss: 0.67715 | val_auc: 0.61047 |  0:03:35s
epoch 28 | loss: 0.67667 | val_auc: 0.60139 |  0:03:43s
epoch 29 | loss: 0.67677 | val_auc: 0.61035 |  0:03:51s
epoch 30 | loss: 0.67653 | val_auc: 0.61154 |  0:03:58s
epoch 31 | loss: 0.67662 | val_auc: 0.61235 |  0:04:06s
epoch 32 | loss: 0.67587 | val_auc: 0.61207 |  0:04:14s
epoch 33 | loss: 0.67723 | val_auc: 0.61022 |  0:04:21s
epoch 34 | loss: 0.67644 | val_auc: 0.61357 |  0:04:29s
epoch 35 | loss: 0.67582 | val_auc: 0.6133  |  0:04:37s
epoch 36 | loss: 0.67542 | val_auc: 0.61454 |  0:04:45s
epoch 37 | loss: 0.6765  | val_auc: 0.61416 |  0:04:52s
epoch 38 | loss: 0.67528 | val_auc: 0.61201 |  0:05:00s
epoch 39 | loss: 0.67663 | val_auc: 0.61751 |  0:05:07s
epoch 40 | loss: 0.67595 | val_auc: 0.61399 |  0:05:15s
epoch 41 | loss: 0.6749  | val_auc: 0.61576 |  0:05:23s
epoch 42 | loss: 0.67498 | val_auc: 0.61713 |  0:05:31s
epoch 43 | loss: 0.67593 | val_auc: 0.6126  |  0:05:40s
epoch 44 | loss: 0.67587 | val_auc: 0.61462 |  0:05:48s
epoch 45 | loss: 0.67589 | val_auc: 0.61386 |  0:05:56s
epoch 46 | loss: 0.67656 | val_auc: 0.61261 |  0:06:03s
epoch 47 | loss: 0.67645 | val_auc: 0.61577 |  0:06:11s
epoch 48 | loss: 0.67534 | val_auc: 0.61755 |  0:06:19s
epoch 49 | loss: 0.67701 | val_auc: 0.61101 |  0:06:28s
epoch 50 | loss: 0.67597 | val_auc: 0.61447 |  0:06:37s
epoch 51 | loss: 0.67624 | val_auc: 0.61677 |  0:06:46s
epoch 52 | loss: 0.67408 | val_auc: 0.61248 |  0:06:54s
epoch 53 | loss: 0.67487 | val_auc: 0.615   |  0:07:03s
epoch 54 | loss: 0.67807 | val_auc: 0.61561 |  0:07:11s
epoch 55 | loss: 0.67628 | val_auc: 0.61546 |  0:07:19s
epoch 56 | loss: 0.67576 | val_auc: 0.61431 |  0:07:27s
epoch 57 | loss: 0.67523 | val_auc: 0.61536 |  0:07:35s
epoch 58 | loss: 0.67759 | val_auc: 0.61397 |  0:07:43s
epoch 59 | loss: 0.67598 | val_auc: 0.61615 |  0:07:51s
epoch 60 | loss: 0.67776 | val_auc: 0.61799 |  0:07:59s
epoch 61 | loss: 0.67584 | val_auc: 0.61838 |  0:08:07s
epoch 62 | loss: 0.67573 | val_auc: 0.61738 |  0:08:15s
epoch 63 | loss: 0.67548 | val_auc: 0.62134 |  0:08:23s
epoch 64 | loss: 0.67541 | val_auc: 0.62044 |  0:08:31s
epoch 65 | loss: 0.67432 | val_auc: 0.62006 |  0:08:38s
epoch 66 | loss: 0.67386 | val_auc: 0.61963 |  0:08:46s
epoch 67 | loss: 0.67544 | val_auc: 0.6218  |  0:08:54s
epoch 68 | loss: 0.67277 | val_auc: 0.61955 |  0:09:02s
epoch 69 | loss: 0.6743  | val_auc: 0.62118 |  0:09:11s
epoch 70 | loss: 0.67474 | val_auc: 0.62311 |  0:09:19s
epoch 71 | loss: 0.67403 | val_auc: 0.62257 |  0:09:27s
epoch 72 | loss: 0.67199 | val_auc: 0.62108 |  0:09:35s
epoch 73 | loss: 0.67398 | val_auc: 0.6246  |  0:09:44s
epoch 74 | loss: 0.67331 | val_auc: 0.62247 |  0:09:52s
epoch 75 | loss: 0.67452 | val_auc: 0.62421 |  0:10:00s
epoch 76 | loss: 0.67449 | val_auc: 0.62342 |  0:10:09s
epoch 77 | loss: 0.67429 | val_auc: 0.61967 |  0:10:17s
epoch 78 | loss: 0.67434 | val_auc: 0.625   |  0:10:26s
epoch 79 | loss: 0.6737  | val_auc: 0.62474 |  0:10:34s
epoch 80 | loss: 0.67297 | val_auc: 0.62264 |  0:10:45s
epoch 81 | loss: 0.67493 | val_auc: 0.6265  |  0:10:55s
epoch 82 | loss: 0.67166 | val_auc: 0.62601 |  0:11:04s
epoch 83 | loss: 0.67239 | val_auc: 0.62398 |  0:11:14s
epoch 84 | loss: 0.67228 | val_auc: 0.62457 |  0:11:23s
epoch 85 | loss: 0.67285 | val_auc: 0.62422 |  0:11:32s
epoch 86 | loss: 0.67159 | val_auc: 0.62474 |  0:11:41s
epoch 87 | loss: 0.67553 | val_auc: 0.62966 |  0:11:51s
epoch 88 | loss: 0.67158 | val_auc: 0.62621 |  0:11:59s
epoch 89 | loss: 0.67214 | val_auc: 0.62551 |  0:12:08s
epoch 90 | loss: 0.67276 | val_auc: 0.62614 |  0:12:17s
epoch 91 | loss: 0.67125 | val_auc: 0.62688 |  0:12:28s
epoch 92 | loss: 0.67053 | val_auc: 0.62729 |  0:12:38s
epoch 93 | loss: 0.67084 | val_auc: 0.62689 |  0:12:47s
epoch 94 | loss: 0.67016 | val_auc: 0.62744 |  0:12:56s
epoch 95 | loss: 0.6711  | val_auc: 0.62668 |  0:13:07s
epoch 96 | loss: 0.67111 | val_auc: 0.62823 |  0:13:20s
epoch 97 | loss: 0.67122 | val_auc: 0.62452 |  0:13:30s
epoch 98 | loss: 0.67187 | val_auc: 0.62327 |  0:13:40s
epoch 99 | loss: 0.67294 | val_auc: 0.63076 |  0:13:49s
epoch 100| loss: 0.67022 | val_auc: 0.63146 |  0:13:59s
epoch 101| loss: 0.67218 | val_auc: 0.63074 |  0:14:08s
epoch 102| loss: 0.67101 | val_auc: 0.62925 |  0:14:18s
epoch 103| loss: 0.67286 | val_auc: 0.63147 |  0:14:29s
epoch 104| loss: 0.67242 | val_auc: 0.63113 |  0:14:39s
epoch 105| loss: 0.67051 | val_auc: 0.6318  |  0:14:48s
epoch 106| loss: 0.67121 | val_auc: 0.63181 |  0:14:57s
epoch 107| loss: 0.67174 | val_auc: 0.63271 |  0:15:07s
epoch 108| loss: 0.67015 | val_auc: 0.63342 |  0:15:16s
epoch 109| loss: 0.67167 | val_auc: 0.63363 |  0:15:25s
epoch 110| loss: 0.67037 | val_auc: 0.63235 |  0:15:35s
epoch 111| loss: 0.66924 | val_auc: 0.63267 |  0:15:44s
epoch 112| loss: 0.66941 | val_auc: 0.63205 |  0:15:53s
epoch 113| loss: 0.67201 | val_auc: 0.63283 |  0:16:03s
epoch 114| loss: 0.67087 | val_auc: 0.63305 |  0:16:13s
epoch 115| loss: 0.6702  | val_auc: 0.6352  |  0:16:23s
epoch 116| loss: 0.67071 | val_auc: 0.63337 |  0:16:33s
epoch 117| loss: 0.67049 | val_auc: 0.6306  |  0:16:42s
epoch 118| loss: 0.66964 | val_auc: 0.63496 |  0:16:52s
epoch 119| loss: 0.66975 | val_auc: 0.63274 |  0:17:02s
epoch 120| loss: 0.67164 | val_auc: 0.63474 |  0:17:14s
epoch 121| loss: 0.67245 | val_auc: 0.63343 |  0:17:23s
epoch 122| loss: 0.67056 | val_auc: 0.63268 |  0:17:34s
epoch 123| loss: 0.67062 | val_auc: 0.63371 |  0:17:45s
epoch 124| loss: 0.67194 | val_auc: 0.63116 |  0:17:56s
epoch 125| loss: 0.66845 | val_auc: 0.633   |  0:18:07s
epoch 126| loss: 0.66861 | val_auc: 0.63384 |  0:18:17s
epoch 127| loss: 0.67012 | val_auc: 0.63194 |  0:18:26s
epoch 128| loss: 0.67204 | val_auc: 0.62821 |  0:18:36s
epoch 129| loss: 0.67145 | val_auc: 0.63158 |  0:18:45s
epoch 130| loss: 0.67129 | val_auc: 0.63305 |  0:18:55s
epoch 131| loss: 0.67133 | val_auc: 0.6317  |  0:19:04s
epoch 132| loss: 0.6705  | val_auc: 0.63297 |  0:19:13s
epoch 133| loss: 0.67003 | val_auc: 0.63473 |  0:19:23s
epoch 134| loss: 0.67035 | val_auc: 0.63209 |  0:19:32s
epoch 135| loss: 0.67083 | val_auc: 0.63066 |  0:19:42s

Early stopping occurred at epoch 135 with best_epoch = 115 and best_val_auc = 0.6352
              precision    recall  f1-score   support

           0     0.7645    0.4866    0.5947      6790
           1     0.4054    0.7001    0.5135      3395

    accuracy                         0.5578     10185
   macro avg     0.5849    0.5934    0.5541     10185
weighted avg     0.6448    0.5578    0.5676     10185

Fold 2 AUC: 0.6352
MCC: 0.1781
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
==== Fold 3 ====
epoch 0  | loss: 0.7218  | val_auc: 0.55885 |  0:00:09s
epoch 1  | loss: 0.6882  | val_auc: 0.59358 |  0:00:19s
epoch 2  | loss: 0.68031 | val_auc: 0.59922 |  0:00:30s
epoch 3  | loss: 0.67972 | val_auc: 0.60477 |  0:00:40s
epoch 4  | loss: 0.67724 | val_auc: 0.61739 |  0:00:49s
epoch 5  | loss: 0.67683 | val_auc: 0.60908 |  0:00:58s
epoch 6  | loss: 0.67877 | val_auc: 0.59999 |  0:01:08s
epoch 7  | loss: 0.67584 | val_auc: 0.60837 |  0:01:17s
epoch 8  | loss: 0.67571 | val_auc: 0.60798 |  0:01:26s
epoch 9  | loss: 0.67433 | val_auc: 0.60073 |  0:01:36s
epoch 10 | loss: 0.67381 | val_auc: 0.61809 |  0:01:45s
epoch 11 | loss: 0.67232 | val_auc: 0.61666 |  0:01:55s
epoch 12 | loss: 0.67153 | val_auc: 0.61992 |  0:02:04s
epoch 13 | loss: 0.6725  | val_auc: 0.61924 |  0:02:13s
epoch 14 | loss: 0.67408 | val_auc: 0.61111 |  0:02:23s
epoch 15 | loss: 0.67444 | val_auc: 0.61399 |  0:02:33s
epoch 16 | loss: 0.67469 | val_auc: 0.61232 |  0:02:42s
epoch 17 | loss: 0.67318 | val_auc: 0.61808 |  0:02:52s
epoch 18 | loss: 0.67378 | val_auc: 0.62536 |  0:03:01s
epoch 19 | loss: 0.67363 | val_auc: 0.61995 |  0:03:11s
epoch 20 | loss: 0.6731  | val_auc: 0.62426 |  0:03:20s
epoch 21 | loss: 0.67163 | val_auc: 0.62367 |  0:03:30s
epoch 22 | loss: 0.6694  | val_auc: 0.63273 |  0:03:39s
epoch 23 | loss: 0.66939 | val_auc: 0.63642 |  0:03:48s
epoch 24 | loss: 0.66927 | val_auc: 0.63103 |  0:03:58s
epoch 25 | loss: 0.66685 | val_auc: 0.6307  |  0:04:07s
epoch 26 | loss: 0.66971 | val_auc: 0.62782 |  0:04:19s
epoch 27 | loss: 0.66645 | val_auc: 0.63065 |  0:04:30s
epoch 28 | loss: 0.66788 | val_auc: 0.6314  |  0:04:40s
epoch 29 | loss: 0.66949 | val_auc: 0.63063 |  0:04:50s
epoch 30 | loss: 0.66758 | val_auc: 0.63321 |  0:04:59s
epoch 31 | loss: 0.6698  | val_auc: 0.63539 |  0:05:09s
epoch 32 | loss: 0.66682 | val_auc: 0.63332 |  0:05:18s
epoch 33 | loss: 0.66664 | val_auc: 0.63623 |  0:05:28s
epoch 34 | loss: 0.66574 | val_auc: 0.64074 |  0:05:37s
epoch 35 | loss: 0.66666 | val_auc: 0.64068 |  0:05:46s
epoch 36 | loss: 0.6663  | val_auc: 0.63394 |  0:05:55s
epoch 37 | loss: 0.66692 | val_auc: 0.63794 |  0:06:04s
epoch 38 | loss: 0.66674 | val_auc: 0.63428 |  0:06:13s
epoch 39 | loss: 0.66658 | val_auc: 0.63457 |  0:06:22s
epoch 40 | loss: 0.66576 | val_auc: 0.63829 |  0:06:31s
epoch 41 | loss: 0.66497 | val_auc: 0.64095 |  0:06:39s
epoch 42 | loss: 0.66518 | val_auc: 0.64047 |  0:06:48s
epoch 43 | loss: 0.66518 | val_auc: 0.63705 |  0:06:56s
epoch 44 | loss: 0.66577 | val_auc: 0.63856 |  0:07:04s
epoch 45 | loss: 0.66607 | val_auc: 0.64073 |  0:07:14s
epoch 46 | loss: 0.66486 | val_auc: 0.638   |  0:07:22s
epoch 47 | loss: 0.66572 | val_auc: 0.64164 |  0:07:30s
epoch 48 | loss: 0.66593 | val_auc: 0.64362 |  0:07:38s
epoch 49 | loss: 0.66622 | val_auc: 0.63904 |  0:07:46s
epoch 50 | loss: 0.66453 | val_auc: 0.63709 |  0:07:54s
epoch 51 | loss: 0.66724 | val_auc: 0.64113 |  0:08:02s
epoch 52 | loss: 0.6647  | val_auc: 0.64107 |  0:08:09s
epoch 53 | loss: 0.66428 | val_auc: 0.63805 |  0:08:17s
epoch 54 | loss: 0.66469 | val_auc: 0.63981 |  0:08:26s
epoch 55 | loss: 0.66551 | val_auc: 0.6389  |  0:08:34s
epoch 56 | loss: 0.66675 | val_auc: 0.64354 |  0:08:42s
epoch 57 | loss: 0.66501 | val_auc: 0.63528 |  0:08:49s
epoch 58 | loss: 0.66621 | val_auc: 0.63981 |  0:08:57s
epoch 59 | loss: 0.66515 | val_auc: 0.64207 |  0:09:05s
epoch 60 | loss: 0.66393 | val_auc: 0.64397 |  0:09:13s
epoch 61 | loss: 0.66424 | val_auc: 0.64084 |  0:09:20s
epoch 62 | loss: 0.66511 | val_auc: 0.64313 |  0:09:28s
epoch 63 | loss: 0.66572 | val_auc: 0.64024 |  0:09:36s
epoch 64 | loss: 0.66487 | val_auc: 0.64332 |  0:09:44s
epoch 65 | loss: 0.66401 | val_auc: 0.64226 |  0:09:51s
epoch 66 | loss: 0.66402 | val_auc: 0.63947 |  0:09:59s
epoch 67 | loss: 0.66375 | val_auc: 0.64397 |  0:10:07s
epoch 68 | loss: 0.6645  | val_auc: 0.64077 |  0:10:15s
epoch 69 | loss: 0.66196 | val_auc: 0.64171 |  0:10:22s
epoch 70 | loss: 0.66386 | val_auc: 0.64178 |  0:10:31s
epoch 71 | loss: 0.66337 | val_auc: 0.64024 |  0:10:39s
epoch 72 | loss: 0.66184 | val_auc: 0.63714 |  0:10:47s
epoch 73 | loss: 0.66477 | val_auc: 0.63992 |  0:10:55s
epoch 74 | loss: 0.66346 | val_auc: 0.64097 |  0:11:03s
epoch 75 | loss: 0.66362 | val_auc: 0.63924 |  0:11:11s
epoch 76 | loss: 0.66526 | val_auc: 0.63897 |  0:11:19s
epoch 77 | loss: 0.66341 | val_auc: 0.64222 |  0:11:27s
epoch 78 | loss: 0.66433 | val_auc: 0.64289 |  0:11:35s
epoch 79 | loss: 0.6638  | val_auc: 0.64276 |  0:11:43s
epoch 80 | loss: 0.66278 | val_auc: 0.64011 |  0:11:50s

Early stopping occurred at epoch 80 with best_epoch = 60 and best_val_auc = 0.64397
              precision    recall  f1-score   support

           0     0.7683    0.5563    0.6453      6790
           1     0.4282    0.6645    0.5208      3395

    accuracy                         0.5923     10185
   macro avg     0.5982    0.6104    0.5830     10185
weighted avg     0.6549    0.5923    0.6038     10185

Fold 3 AUC: 0.6440
MCC: 0.2083
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
==== Fold 4 ====
epoch 0  | loss: 0.71392 | val_auc: 0.59768 |  0:00:08s
epoch 1  | loss: 0.68596 | val_auc: 0.59188 |  0:00:16s
epoch 2  | loss: 0.68274 | val_auc: 0.60029 |  0:00:26s
epoch 3  | loss: 0.67813 | val_auc: 0.60294 |  0:00:36s
epoch 4  | loss: 0.67905 | val_auc: 0.61214 |  0:00:44s
epoch 5  | loss: 0.67735 | val_auc: 0.61377 |  0:00:52s
epoch 6  | loss: 0.67779 | val_auc: 0.61148 |  0:01:00s
epoch 7  | loss: 0.67923 | val_auc: 0.60758 |  0:01:08s
epoch 8  | loss: 0.67706 | val_auc: 0.60488 |  0:01:16s
epoch 9  | loss: 0.67651 | val_auc: 0.60847 |  0:01:25s
epoch 10 | loss: 0.67599 | val_auc: 0.60629 |  0:01:33s
epoch 11 | loss: 0.67524 | val_auc: 0.6185  |  0:01:41s
epoch 12 | loss: 0.67545 | val_auc: 0.6192  |  0:01:49s
epoch 13 | loss: 0.67245 | val_auc: 0.62488 |  0:01:57s
epoch 14 | loss: 0.67228 | val_auc: 0.624   |  0:02:05s
epoch 15 | loss: 0.67096 | val_auc: 0.62234 |  0:02:13s
epoch 16 | loss: 0.67246 | val_auc: 0.62738 |  0:02:21s
epoch 17 | loss: 0.6703  | val_auc: 0.63416 |  0:02:29s
epoch 18 | loss: 0.66959 | val_auc: 0.62254 |  0:02:37s
epoch 19 | loss: 0.67049 | val_auc: 0.62396 |  0:02:45s
epoch 20 | loss: 0.66794 | val_auc: 0.63425 |  0:02:53s
epoch 21 | loss: 0.66636 | val_auc: 0.63345 |  0:03:01s
epoch 22 | loss: 0.66951 | val_auc: 0.63817 |  0:03:09s
epoch 23 | loss: 0.66621 | val_auc: 0.63856 |  0:03:17s
epoch 24 | loss: 0.66728 | val_auc: 0.6298  |  0:03:26s
epoch 25 | loss: 0.66853 | val_auc: 0.63756 |  0:03:37s
epoch 26 | loss: 0.66579 | val_auc: 0.64093 |  0:03:47s
epoch 27 | loss: 0.66558 | val_auc: 0.63392 |  0:03:55s
epoch 28 | loss: 0.66841 | val_auc: 0.62935 |  0:04:04s
epoch 29 | loss: 0.66533 | val_auc: 0.63438 |  0:04:12s
epoch 30 | loss: 0.66557 | val_auc: 0.63502 |  0:08:11s
epoch 31 | loss: 0.6642  | val_auc: 0.63649 |  0:08:19s
epoch 32 | loss: 0.66306 | val_auc: 0.63905 |  0:08:26s
epoch 33 | loss: 0.66493 | val_auc: 0.63957 |  0:21:32s
epoch 34 | loss: 0.66514 | val_auc: 0.64148 |  0:21:40s
epoch 35 | loss: 0.66407 | val_auc: 0.63754 |  0:21:46s
epoch 36 | loss: 0.66367 | val_auc: 0.64246 |  0:21:53s
epoch 37 | loss: 0.66323 | val_auc: 0.63896 |  0:21:59s
epoch 38 | loss: 0.66421 | val_auc: 0.64307 |  0:22:06s
epoch 39 | loss: 0.66146 | val_auc: 0.64007 |  0:22:13s
epoch 40 | loss: 0.66237 | val_auc: 0.64468 |  0:41:44s
epoch 41 | loss: 0.66187 | val_auc: 0.64453 |  0:41:56s
epoch 42 | loss: 0.65994 | val_auc: 0.64117 |  0:42:04s
epoch 43 | loss: 0.65739 | val_auc: 0.63615 |  0:42:12s
epoch 44 | loss: 0.65858 | val_auc: 0.64934 |  0:42:21s
epoch 45 | loss: 0.66051 | val_auc: 0.64757 |  0:42:27s
epoch 46 | loss: 0.65543 | val_auc: 0.64908 |  0:42:34s
epoch 47 | loss: 0.65763 | val_auc: 0.6431  |  0:42:40s
epoch 48 | loss: 0.65457 | val_auc: 0.64378 |  0:42:46s
epoch 49 | loss: 0.6593  | val_auc: 0.64737 |  0:42:53s
epoch 50 | loss: 0.65383 | val_auc: 0.65483 |  0:42:59s
epoch 51 | loss: 0.65388 | val_auc: 0.65875 |  0:43:06s
epoch 52 | loss: 0.65233 | val_auc: 0.65741 |  0:43:13s
epoch 53 | loss: 0.64907 | val_auc: 0.66099 |  0:43:19s
epoch 54 | loss: 0.64801 | val_auc: 0.66493 |  0:43:26s
epoch 55 | loss: 0.64823 | val_auc: 0.65804 |  0:43:34s
epoch 56 | loss: 0.64292 | val_auc: 0.66288 |  0:43:41s
epoch 57 | loss: 0.64196 | val_auc: 0.67365 |  0:43:48s
epoch 58 | loss: 0.6332  | val_auc: 0.66825 |  0:43:55s
epoch 59 | loss: 0.62525 | val_auc: 0.68516 |  0:44:02s
epoch 60 | loss: 0.61705 | val_auc: 0.69033 |  0:44:09s
epoch 61 | loss: 0.60889 | val_auc: 0.71261 |  0:44:17s
epoch 62 | loss: 0.59198 | val_auc: 0.72509 |  0:44:24s
epoch 63 | loss: 0.58121 | val_auc: 0.73198 |  0:44:31s
epoch 64 | loss: 0.57331 | val_auc: 0.7366  |  0:44:38s
epoch 65 | loss: 0.56609 | val_auc: 0.68992 |  0:44:45s
epoch 66 | loss: 0.55844 | val_auc: 0.71954 |  0:44:53s
epoch 67 | loss: 0.55355 | val_auc: 0.72913 |  0:45:00s
epoch 68 | loss: 0.54844 | val_auc: 0.72841 |  0:45:07s
epoch 69 | loss: 0.54431 | val_auc: 0.733   |  0:45:14s
epoch 70 | loss: 0.53927 | val_auc: 0.71739 |  0:45:22s
epoch 71 | loss: 0.5337  | val_auc: 0.73548 |  0:45:29s
epoch 72 | loss: 0.53193 | val_auc: 0.70763 |  0:45:36s
epoch 73 | loss: 0.53044 | val_auc: 0.71388 |  0:45:43s
epoch 74 | loss: 0.52454 | val_auc: 0.73895 |  0:45:51s
epoch 75 | loss: 0.5225  | val_auc: 0.73833 |  0:45:58s
epoch 76 | loss: 0.51702 | val_auc: 0.74352 |  0:46:05s
epoch 77 | loss: 0.52091 | val_auc: 0.72115 |  0:46:13s
epoch 78 | loss: 0.51458 | val_auc: 0.73361 |  0:46:20s
epoch 79 | loss: 0.51612 | val_auc: 0.71282 |  0:46:28s
epoch 80 | loss: 0.51007 | val_auc: 0.74158 |  0:46:35s
epoch 81 | loss: 0.51016 | val_auc: 0.73994 |  0:46:43s
epoch 82 | loss: 0.50892 | val_auc: 0.65165 |  0:46:50s
epoch 83 | loss: 0.51212 | val_auc: 0.73442 |  0:46:57s
epoch 84 | loss: 0.50758 | val_auc: 0.734   |  0:47:05s
epoch 85 | loss: 0.50531 | val_auc: 0.74106 |  0:47:12s
epoch 86 | loss: 0.50151 | val_auc: 0.71251 |  0:47:20s
epoch 87 | loss: 0.50668 | val_auc: 0.72136 |  0:47:27s
epoch 88 | loss: 0.50778 | val_auc: 0.73639 |  0:47:35s
epoch 89 | loss: 0.50359 | val_auc: 0.72468 |  0:47:44s
epoch 90 | loss: 0.50021 | val_auc: 0.72629 |  0:47:52s
epoch 91 | loss: 0.50154 | val_auc: 0.72505 |  0:47:59s
epoch 92 | loss: 0.5002  | val_auc: 0.74648 |  0:48:07s
epoch 93 | loss: 0.49619 | val_auc: 0.72082 |  0:48:14s
epoch 94 | loss: 0.49808 | val_auc: 0.70147 |  0:48:22s
epoch 95 | loss: 0.49894 | val_auc: 0.7272  |  0:48:29s
epoch 96 | loss: 0.49718 | val_auc: 0.73616 |  0:48:37s
epoch 97 | loss: 0.49433 | val_auc: 0.75113 |  0:48:44s
epoch 98 | loss: 0.49574 | val_auc: 0.72391 |  0:48:52s
epoch 99 | loss: 0.49716 | val_auc: 0.69928 |  0:48:59s
epoch 100| loss: 0.49181 | val_auc: 0.73015 |  0:49:07s
epoch 101| loss: 0.49825 | val_auc: 0.70603 |  0:49:14s
epoch 102| loss: 0.49114 | val_auc: 0.75679 |  0:49:22s
epoch 103| loss: 0.49518 | val_auc: 0.7136  |  0:49:30s
epoch 104| loss: 0.49372 | val_auc: 0.72138 |  0:49:37s
epoch 105| loss: 0.49357 | val_auc: 0.7604  |  0:49:45s
epoch 106| loss: 0.49471 | val_auc: 0.75576 |  0:49:53s
epoch 107| loss: 0.4902  | val_auc: 0.68895 |  0:50:00s
epoch 108| loss: 0.49344 | val_auc: 0.49985 |  0:50:08s
epoch 109| loss: 0.4924  | val_auc: 0.75216 |  0:50:16s
epoch 110| loss: 0.48773 | val_auc: 0.75372 |  0:50:23s
epoch 111| loss: 0.48885 | val_auc: 0.75479 |  0:50:31s
epoch 112| loss: 0.4902  | val_auc: 0.55466 |  0:50:39s
epoch 113| loss: 0.48912 | val_auc: 0.73351 |  0:50:46s
epoch 114| loss: 0.48739 | val_auc: 0.72823 |  0:50:54s
epoch 115| loss: 0.48659 | val_auc: 0.75073 |  0:51:02s
epoch 116| loss: 0.48825 | val_auc: 0.7511  |  0:51:10s
epoch 117| loss: 0.4894  | val_auc: 0.68768 |  0:51:18s
epoch 118| loss: 0.48623 | val_auc: 0.69732 |  0:51:26s
epoch 119| loss: 0.4896  | val_auc: 0.74797 |  0:51:34s
epoch 120| loss: 0.48947 | val_auc: 0.75605 |  0:51:43s
epoch 121| loss: 0.48622 | val_auc: 0.74973 |  0:51:51s
epoch 122| loss: 0.48344 | val_auc: 0.71532 |  0:51:58s
epoch 123| loss: 0.48918 | val_auc: 0.73061 |  0:52:06s
epoch 124| loss: 0.48184 | val_auc: 0.64659 |  0:52:14s
epoch 125| loss: 0.48205 | val_auc: 0.74083 |  0:52:22s

Early stopping occurred at epoch 125 with best_epoch = 105 and best_val_auc = 0.7604
              precision    recall  f1-score   support

           0     0.7906    1.0000    0.8831      6790
           1     1.0000    0.4702    0.6397      3394

    accuracy                         0.8234     10184
   macro avg     0.8953    0.7351    0.7614     10184
weighted avg     0.8604    0.8234    0.8020     10184

Fold 4 AUC: 0.7604
MCC: 0.6097
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
==== Fold 5 ====
epoch 0  | loss: 0.71305 | val_auc: 0.57825 |  0:00:07s
epoch 1  | loss: 0.68662 | val_auc: 0.58927 |  0:00:15s
epoch 2  | loss: 0.68255 | val_auc: 0.60638 |  0:00:23s
epoch 3  | loss: 0.67729 | val_auc: 0.6091  |  0:00:31s
epoch 4  | loss: 0.67632 | val_auc: 0.61426 |  0:00:39s
epoch 5  | loss: 0.67503 | val_auc: 0.6169  |  0:00:46s
epoch 6  | loss: 0.67486 | val_auc: 0.61544 |  0:00:54s
epoch 7  | loss: 0.6751  | val_auc: 0.61394 |  0:01:02s
epoch 8  | loss: 0.67459 | val_auc: 0.61279 |  0:01:10s
epoch 9  | loss: 0.67463 | val_auc: 0.61803 |  0:01:18s
epoch 10 | loss: 0.67213 | val_auc: 0.62082 |  0:01:25s
epoch 11 | loss: 0.67261 | val_auc: 0.61999 |  0:01:33s
epoch 12 | loss: 0.67108 | val_auc: 0.61774 |  0:01:41s
epoch 13 | loss: 0.66813 | val_auc: 0.62669 |  0:01:48s
epoch 14 | loss: 0.66755 | val_auc: 0.62806 |  0:01:56s
epoch 15 | loss: 0.66766 | val_auc: 0.63502 |  0:02:04s
epoch 16 | loss: 0.66792 | val_auc: 0.6349  |  0:02:12s
epoch 17 | loss: 0.66508 | val_auc: 0.63258 |  0:02:20s
epoch 18 | loss: 0.66732 | val_auc: 0.63808 |  0:02:28s
epoch 19 | loss: 0.66639 | val_auc: 0.62917 |  0:02:36s
epoch 20 | loss: 0.66791 | val_auc: 0.63797 |  0:02:43s
epoch 21 | loss: 0.66584 | val_auc: 0.63302 |  0:02:51s
epoch 22 | loss: 0.66386 | val_auc: 0.63666 |  0:02:59s
epoch 23 | loss: 0.66522 | val_auc: 0.64241 |  0:03:07s
epoch 24 | loss: 0.66441 | val_auc: 0.63817 |  0:03:16s
epoch 25 | loss: 0.66422 | val_auc: 0.63368 |  0:03:25s
epoch 26 | loss: 0.66616 | val_auc: 0.63348 |  0:03:34s
epoch 27 | loss: 0.66565 | val_auc: 0.63597 |  0:03:42s
epoch 28 | loss: 0.66532 | val_auc: 0.63909 |  0:03:50s
epoch 29 | loss: 0.66665 | val_auc: 0.63869 |  0:03:57s
epoch 30 | loss: 0.66418 | val_auc: 0.63934 |  0:04:06s
epoch 31 | loss: 0.66443 | val_auc: 0.64022 |  0:04:14s
epoch 32 | loss: 0.6626  | val_auc: 0.64237 |  0:04:22s
epoch 33 | loss: 0.66573 | val_auc: 0.63978 |  0:04:31s
epoch 34 | loss: 0.66166 | val_auc: 0.64345 |  0:04:39s
epoch 35 | loss: 0.66483 | val_auc: 0.64611 |  0:04:48s
epoch 36 | loss: 0.66584 | val_auc: 0.64238 |  0:04:56s
epoch 37 | loss: 0.66269 | val_auc: 0.64274 |  0:05:05s
epoch 38 | loss: 0.66112 | val_auc: 0.64267 |  0:05:14s
epoch 39 | loss: 0.66537 | val_auc: 0.6401  |  0:05:23s
epoch 40 | loss: 0.66322 | val_auc: 0.64375 |  0:05:32s
epoch 41 | loss: 0.66461 | val_auc: 0.64253 |  0:05:41s
epoch 42 | loss: 0.66167 | val_auc: 0.64271 |  0:05:50s
epoch 43 | loss: 0.66472 | val_auc: 0.64603 |  0:05:59s
epoch 44 | loss: 0.66321 | val_auc: 0.64158 |  0:06:08s
epoch 45 | loss: 0.6622  | val_auc: 0.64519 |  0:06:17s
epoch 46 | loss: 0.66451 | val_auc: 0.64448 |  0:06:27s
epoch 47 | loss: 0.66212 | val_auc: 0.64417 |  0:06:36s
epoch 48 | loss: 0.66134 | val_auc: 0.64511 |  0:06:45s
epoch 49 | loss: 0.6631  | val_auc: 0.6443  |  0:06:54s
epoch 50 | loss: 0.66151 | val_auc: 0.64272 |  0:07:03s
epoch 51 | loss: 0.6633  | val_auc: 0.64404 |  0:07:13s
epoch 52 | loss: 0.66263 | val_auc: 0.64297 |  0:07:22s
epoch 53 | loss: 0.66179 | val_auc: 0.6436  |  0:07:31s
epoch 54 | loss: 0.66498 | val_auc: 0.64279 |  0:07:41s
epoch 55 | loss: 0.66353 | val_auc: 0.64476 |  0:07:50s

Early stopping occurred at epoch 55 with best_epoch = 35 and best_val_auc = 0.64611
              precision    recall  f1-score   support

           0     0.7645    0.5733    0.6552      6789
           1     0.4312    0.6468    0.5174      3395

    accuracy                         0.5978     10184
   macro avg     0.5978    0.6101    0.5863     10184
weighted avg     0.6534    0.5978    0.6093     10184

Fold 5 AUC: 0.6461
MCC: 0.2075
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
==== Cross-Validation Complete ====
Mean AUC: 0.6917 | Std AUC: 0.0614
Successfully saved model at Failure_Event_tabnet_model.zip
Out[105]:
'Failure_Event_tabnet_model.zip'

TAB NET¶

TabNet is a neural network (NN) architecture specifically designed for tabular data.

  • Traditional deep learning models (like MLPs) struggle with tabular data. TabNet, however, was specifically designed for structured datasets — which have usually been better handled by models like XGBoost or Random Forests.
  • TabNet processes data in steps, and at each step, it uses an attention mechanism to decide
    • Which features to focus on
    • How much each feature should influence the prediction
    • This is very different from tree-based models that use greedy splits or MLPs that treat all input features equally at all times.
  • TabNet promotes sparsity in feature usage
    • Each decision step uses only a small subset of features.
    • This makes the model interpretable and efficient
    • This is controlled by the gamma parameter and the entmax activation — which encourages attention to only a few important inputs.
  • Handles Categorical Variables Natively
    • Uses embeddings for categorical variables (like NLP models)
    • No need for manual one-hot encoding or label encoding.
    • Learns better representations for categories during training.
  • TabNet is trained end-to-end using gradient descent, which makes it
    • More flexible and scalable
    • Capable of benefiting from powerful optimization tools like learning rate schedulers, early stopping, etc.
  • Comparison of models
Feature TabNet XGBoost / Random Forest MLP (Feedforward Neural Net)
Feature Selection Attention-based, sparse Tree splits Implicit (weights)
Interpretability High (built-in feature masks) Medium (requires SHAP/LIME) Low
Categorical Handling Native (embeddings) Manual encoding (one-hot/label) Manual encoding (one-hot/label)
Sequential Decision Steps Yes No No
Handles Imbalanced Data Yes (class weights + attention) Yes (with tuning) Needs balancing (e.g., SMOTE)
Training End-to-end gradient descent Gradient boosting / bagging End-to-end gradient descent
Performance on Tabular Very competitive Strong baseline Often underperforms

TAB NET specific parameters¶

 clf_F = TabNetClassifier(
        n_d=32, n_a=32, n_steps=5, gamma=1.5,
        cat_idxs=cat_idxs,
        cat_dims=cat_dims,
        cat_emb_dim=cat_emb_dim,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=1e-2),
        scheduler_params={"step_size":10, "gamma":0.9},
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        mask_type='entmax',
        seed=RANDOM_STATE,
        verbose=1
    ) 
  • n_d=32: Number of dimensions for the decision step output. Controls the size of the vector that holds learned representations at each step.
  • n_a=32: Number of dimensions for the attention step output. Controls how much attention each feature gets during selection. Usually set equal to n_d.
  • n_steps=5: Number of sequential decision steps. Each step decides which features to focus on and learns new representations. More steps allow more complex reasoning but increase computation.
  • gamma=1.5: Controls the sparsity of feature selection. Higher gamma results in more sparse attention (focus on fewer features). This enhances interpretability and helps regularize the model.
  • cat_idxs=cat_idxs: List of column indices for categorical features in the dataset. Tells TabNet which features require embedding.
  • cat_dims=cat_dims: List containing the number of unique values for each categorical column. Used to define embedding layer dimensions.
  • cat_emb_dim=cat_emb_dim: List of embedding dimensions for each categorical feature. Typically set using a rule like min(50, (cat_dim + 1) // 2).
  • optimizer_fn=torch.optim.Adam: Specifies the optimizer to use. Adam optimizer is used here for its efficiency and adaptive learning rate.
  • optimizer_params=dict(lr=1e-2): Sets the initial learning rate to 0.01.
  • scheduler_params={"step_size":10, "gamma":0.9}: Learning rate scheduler parameters. Every 10 epochs, the learning rate is multiplied by 0.9 to gradually reduce it.
  • scheduler_fn=torch.optim.lr_scheduler.StepLR: The function used to reduce the learning rate at fixed intervals. StepLR is a simple and commonly used scheduler.
  • mask_type='entmax': Specifies the attention mask type for feature selection. entmax creates sparse masks, allowing the model to focus only on the most relevant features and ignore the rest (assigned 0 attention), improving interpretability.
clf_F.fit(
        X_train=TAB_X_train.values, y_train=TAB_y_train.values,
        eval_set=[(TAB_X_val.values, TAB_y_val.values)],
        eval_name=['val'],
        eval_metric=['auc'],
        max_epochs=EPOCHS,
        patience=20,
        batch_size=BATCH_SIZE,
        virtual_batch_size=128,
        weights = weights
    )

TabNet fit() Parameter Explanation¶

  • X_train=TAB_X_train.values: The training feature matrix as a NumPy array.

  • y_train=TAB_y_train.values: The target labels for training as a NumPy array.

  • eval_set=[(TAB_X_val.values, TAB_y_val.values)]: A list of tuples containing the validation set (features and labels). Used to monitor model performance during training.

  • eval_name=['val']: The name associated with the evaluation set. Appears in the training logs to identify the validation metrics.

  • eval_metric=['auc']: The evaluation metric to monitor. In this case, Area Under the ROC Curve (AUC) is used to track model performance.

  • max_epochs=EPOCHS: The maximum number of epochs to train the model. Here, it's set using the EPOCHS constant (e.g., 200).

  • patience=20: Early stopping parameter. Training will stop if the validation metric does not improve for 20 consecutive epochs.

  • batch_size=BATCH_SIZE: Number of samples processed in each training batch. Controlled via the BATCH_SIZE variable (e.g., 64).

  • virtual_batch_size=128: Used for Ghost Batch Normalization. Enables batch normalization over smaller subsets of the batch to simulate smaller batch behavior, improving generalization and stability.

  • weights=weights: Class weights used to handle class imbalance. Ensures the model pays appropriate attention to minority classes by penalizing misclassification more heavily.

Best accuracy produced in Fold 4¶

  • Accuracy: 0.8234
  • AUC: 0.7604
  • MCC: 0.6097

For the full evaluation of model metrics and plots, refer to the report.

Evaluation on Untrained data¶

In [106]:
# --- Evaluate the failure-event TabNet (clf_F) on the held-out validation split ---
# Predict probabilities and labels
TAB_x = X_val.values
TAB_y = y_val

y_pred_proba = clf_F.predict_proba(TAB_x)[:, 1]   # probability of class 1
y_pred = clf_F.predict(TAB_x)

# AUC on the validation set
auc_score = roc_auc_score(TAB_y, y_pred_proba)
print(f"\nFinal Validation AUC: {auc_score:.4f}")

# Per-class precision / recall / F1
print("Classification Report:")
print(classification_report(TAB_y, y_pred, digits=4))

# MCC summarises the whole confusion matrix in one number
mcc = matthews_corrcoef(TAB_y, y_pred)
print(f"Matthews Correlation Coefficient (MCC): {mcc:.4f}")

# Confusion matrix
cm = confusion_matrix(TAB_y, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues')
plt.title("Confusion Matrix")
plt.show()

from sklearn.metrics import roc_curve, auc

# ROC curve; auc(fpr, tpr) matches roc_auc_score above and feeds the plot label
fpr, tpr, thresholds = roc_curve(TAB_y, y_pred_proba)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(6, 4))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid()
plt.tight_layout()
plt.show()

# Histogram of predicted probabilities per true class (shows class separation).
# NOTE: a precision_recall_curve / average_precision_score computation used to
# sit here but its results were never printed or plotted — removed as dead code.
plt.figure(figsize=(6, 4))
plt.hist(y_pred_proba[TAB_y == 0], bins=30, alpha=0.6, label='Class 0 (No Failure)', color='skyblue')
plt.hist(y_pred_proba[TAB_y == 1], bins=30, alpha=0.6, label='Class 1 (Failure)', color='salmon')
plt.xlabel('Predicted Probability')
plt.ylabel('Count')
plt.title('Histogram of Predicted Probabilities by Class')
plt.legend()
plt.grid()
plt.tight_layout()
plt.show()

import seaborn as sns

# TabNet's explain() yields a per-sample attribution matrix; summing over the
# sample axis gives one global importance score per feature.
explain_matrix, masks = clf_F.explain(TAB_x)

# Create DataFrame for feature importances
feature_importances_df = pd.DataFrame({
    'Feature': X_val.columns,
    'Importance': explain_matrix.sum(axis=0)
}).sort_values(by='Importance', ascending=False)

# Plot
plt.figure(figsize=(8, 6))
sns.barplot(data=feature_importances_df, x='Importance', y='Feature', palette='viridis')
plt.title("TabNet Feature Importances (Global Explanation)")
plt.tight_layout()
plt.show()
Final Validation AUC: 0.5075
Classification Report:
              precision    recall  f1-score   support

           0     0.8522    0.5678    0.6815      8487
           1     0.1558    0.4475    0.2311      1513

    accuracy                         0.5496     10000
   macro avg     0.5040    0.5076    0.4563     10000
weighted avg     0.7468    0.5496    0.6134     10000

Matthews Correlation Coefficient (MCC): 0.0110
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Predicted probability histogram¶

  • Class 0 (No Failure) predictions are mostly centered around 0.4–0.6.
  • Class 1 (Failure) predictions overlap with Class 0 and also tend to peak around 0.4–0.5.
  • This overlap indicates that the model is:
  • Not highly confident in distinguishing failures from non-failures.
  • Producing probabilities clustered near the middle range instead of near 0 or 1.

Feature importance¶

  • Vibration_Levels has the highest importance. This suggests it's the most predictive feature for the Failure_Event target.
  • Fuel_Consumption, Humidity, and Age_of_Asset also have high importance, indicating strong influence on failure predictions.
  • Moderately useful features:
    • Pressure has some influence but significantly less than the top four.
  • Least used features:
    • Features like Thermal_Stress, Temperature, Fuel_Efficiency, Asset_Type, Maintenance_History, and others at the bottom have very low to negligible importance in this model. It doesn't mean they are useless — just that TabNet didn't find them as helpful in the current data context.

Needs Maintenance¶

Reprepping Data¶

In [107]:
# Derive the binary target: 1 if the asset has any maintenance history (> 0), else 0.
df["needs_maintenance"] = (df["Maintenance_History"] > 0).astype(int)
In [108]:
# Class balance of the new target (output below: ~2:1 in favour of class 1).
df.value_counts('needs_maintenance')
Out[108]:
needs_maintenance
1    33303
0    16697
Name: count, dtype: int64
In [109]:
# Distribution of the raw Maintenance_History values (0, 1, 2 — near-uniform per output).
df.value_counts('Maintenance_History')
Out[109]:
Maintenance_History
2    16714
0    16697
1    16589
Name: count, dtype: int64
In [110]:
# Cross-tabulation of asset type vs the target — imbalance looks similar across asset types.
df[["Asset_Type","needs_maintenance"]].value_counts()
Out[110]:
Asset_Type  needs_maintenance
2           1                    11167
1           1                    11069
0           1                    11067
1           0                     5640
0           0                     5542
2           0                     5515
Name: count, dtype: int64

Data Pre processing¶

In [111]:
# --- Features and target for the maintenance-needs task ---
# Maintenance_History is dropped too: needs_maintenance is derived from it,
# so keeping it as a feature would leak the label.
X = df.drop(columns=['Maintenance_History','needs_maintenance'])
y = df['needs_maintenance']

# Define categorical and numerical features
cat_features = ['Asset_Type', 'Location']
num_features = [col for col in X.columns if col not in cat_features]

# Stratified 80/20 train-test split
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# --- Apply SMOTE to balance the classes in the training data only ---
categorical_indices = [X_train.columns.get_loc(col) for col in cat_features]

print("Before SMOTE:", Counter(y_train))

# SMOTENC handles mixed categorical + continuous features
smote_nc = SMOTENC(
    categorical_features=categorical_indices,
    sampling_strategy=0.7,  # Make minority class 70% the size of the majority class
    random_state=42
)

X_train_balanced, y_train_balanced = smote_nc.fit_resample(X_train, y_train)

# Unscaled copy of the resampled frame. (Fix: the original assigned this to
# X_train_final and then reassigned X_train_final below for a different frame —
# a distinct name removes that hidden-state hazard.)
X_train_resampled = pd.DataFrame(X_train_balanced, columns=X_train.columns)
y_train_final = pd.Series(y_train_balanced, name='needs_maintenance')

print("After SMOTE:", Counter(y_train_balanced))

# --- Scale numeric features AFTER resampling; fit on train only ---
scaler = StandardScaler()

# Separate numerical columns for scaling
X_train_num = X_train_resampled[num_features]
X_val_num = X_val[num_features]

X_train_num_scaled = scaler.fit_transform(X_train_num)
X_val_num_scaled = scaler.transform(X_val_num)

# Back to DataFrames; reset_index keeps rows aligned for the concat below
X_train_num_scaled_df = pd.DataFrame(X_train_num_scaled, columns=num_features).reset_index(drop=True)
X_val_num_scaled_df = pd.DataFrame(X_val_num_scaled, columns=num_features).reset_index(drop=True)

# Final sets: scaled numeric columns + untouched categorical codes
X_train_final = pd.concat([X_train_num_scaled_df, X_train_resampled[cat_features].reset_index(drop=True)], axis=1)
X_val_final = pd.concat([X_val_num_scaled_df, X_val[cat_features].reset_index(drop=True)], axis=1)
Before SMOTE: Counter({1: 26642, 0: 13358})
After SMOTE: Counter({1: 26642, 0: 18649})

Random Forest Model¶

Grid search for best parameters¶

In [112]:
# Base estimator: balanced class weights, all CPU cores, fixed seed.
rf = RandomForestClassifier(random_state=42, n_jobs=-1, class_weight='balanced')

# Hyperparameter grid — the depth/leaf constraints act as pre-pruning.
param_grid = {
    'n_estimators': [100],
    'max_depth': [3, 5, 7, None],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    'max_features': ['sqrt', 'log2'],
}

# 3-fold grid search optimising ROC-AUC on the SMOTE-balanced training data.
# Trees are scale-invariant, so the unscaled frame is used here.
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=3, scoring='roc_auc', verbose=1)
grid_search.fit(X_train_balanced, y_train_balanced)

# Evaluate the tuned estimator on the untouched validation split.
best_rf_M = grid_search.best_estimator_
y_pred = best_rf_M.predict(X_val)
y_prob = best_rf_M.predict_proba(X_val)[:, 1]

print("Best Params:", grid_search.best_params_)
print("AUC-ROC:", roc_auc_score(y_val, y_prob))
print(classification_report(y_val, y_pred, digits=4))
mcc = matthews_corrcoef(y_val, y_pred)
print(f"MCC: {mcc:.4f}")

# Visualise the top three levels of the first tree in the forest.
plt.figure(figsize=(30, 20))
plot_tree(
    best_rf_M.estimators_[0],
    feature_names=X.columns,
    class_names=['Class 0', 'Class 1'],
    filled=True,
    rounded=True,
    max_depth=3,
    fontsize=10,
)
plt.title("Random Forest - Tree Visualization")
plt.show()

# Persist the tuned model for reuse.
joblib.dump(best_rf_M, 'maintainence_needs_random_forest_model.joblib')
Fitting 3 folds for each of 72 candidates, totalling 216 fits
Best Params: {'max_depth': None, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 100}
AUC-ROC: 0.5045436419698883
              precision    recall  f1-score   support

           0     0.3440    0.1518    0.2107      3339
           1     0.6678    0.8548    0.7499      6661

    accuracy                         0.6201     10000
   macro avg     0.5059    0.5033    0.4803     10000
weighted avg     0.5597    0.6201    0.5698     10000

MCC: 0.0089
No description has been provided for this image
Out[112]:
['maintainence_needs_random_forest_model.joblib']

K-fold on best model¶

In [113]:
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import (
    roc_auc_score,
    classification_report,
    matthews_corrcoef,
    confusion_matrix,
    ConfusionMatrixDisplay,
    precision_recall_curve,
    auc  # re-import: an earlier cell shadowed `auc` with a float value
)

# Stratified 5-fold CV on the SMOTE-balanced training data
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Per-fold AUCs collected for the summary at the end
roc_auc_scores = []

# Loop over each fold
for fold, (train_idx, val_idx) in enumerate(cv.split(X_train_balanced, y_train_balanced), start=1):
    X_train_fold, X_val_fold = X_train_balanced.iloc[train_idx], X_train_balanced.iloc[val_idx]
    y_train_fold, y_val_fold = y_train_balanced.iloc[train_idx], y_train_balanced.iloc[val_idx]
    
    # Re-fit the grid-searched estimator on this fold's training portion
    best_rf_M.fit(X_train_fold, y_train_fold)
    y_pred = best_rf_M.predict(X_val_fold)
    y_proba = best_rf_M.predict_proba(X_val_fold)[:, 1]
    
    # AUC for this fold
    roc_auc = roc_auc_score(y_val_fold, y_proba)
    roc_auc_scores.append(roc_auc)
    
    # Per-fold classification metrics
    print(f"\nFold {fold} - Classification Report:")
    print(classification_report(y_val_fold, y_pred, digits=4))
    print(f"Fold {fold} - AUC-ROC: {roc_auc:.4f}")
    mcc = matthews_corrcoef(y_val_fold, y_pred)
    print(f"MCC: {mcc:.4f}")
    # Confusion Matrix
    cm = confusion_matrix(y_val_fold, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap='Blues')
    plt.title("Confusion Matrix")
    plt.show()

    # ROC curve points. (Fix: the duplicate `roc_auc_curve = auc(fpr, tpr)` was
    # never used — the plot label uses roc_auc from roc_auc_score — so it is removed.)
    fpr, tpr, thresholds = roc_curve(y_val_fold, y_proba)
    
    # Histogram of predicted probabilities by true class
    plt.figure(figsize=(6, 4))
    plt.hist(y_proba[y_val_fold == 0], bins=30, alpha=0.6, label='Class 0 (Does not require)', color='skyblue')
    plt.hist(y_proba[y_val_fold == 1], bins=30, alpha=0.6, label='Class 1 (Needs_Maintainence)', color='salmon')
    plt.xlabel('Predicted Probability')
    plt.ylabel('Count')
    plt.title('Histogram of Predicted Probabilities by Class')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()

    # Plot ROC Curve
    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f"ROC AUC = {roc_auc:.4f}")
    plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.show()

# Print average AUC over all folds
print("\nAverage AUC-ROC across folds:")
print(f"{np.mean(roc_auc_scores):.4f} ± {np.std(roc_auc_scores):.4f}")
Fold 1 - Classification Report:
              precision    recall  f1-score   support

           0     0.5921    0.3051    0.4027      3730
           1     0.6368    0.8529    0.7292      5329

    accuracy                         0.6273      9059
   macro avg     0.6145    0.5790    0.5659      9059
weighted avg     0.6184    0.6273    0.5948      9059

Fold 1 - AUC-ROC: 0.6146
MCC: 0.1902
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Fold 2 - Classification Report:
              precision    recall  f1-score   support

           0     0.5830    0.3138    0.4079      3729
           1     0.6371    0.8429    0.7257      5329

    accuracy                         0.6251      9058
   macro avg     0.6100    0.5783    0.5668      9058
weighted avg     0.6148    0.6251    0.5949      9058

Fold 2 - AUC-ROC: 0.6106
MCC: 0.1857
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Fold 3 - Classification Report:
              precision    recall  f1-score   support

           0     0.5851    0.2968    0.3938      3730
           1     0.6340    0.8527    0.7272      5328

    accuracy                         0.6238      9058
   macro avg     0.6095    0.5747    0.5605      9058
weighted avg     0.6138    0.6238    0.5899      9058

Fold 3 - AUC-ROC: 0.6084
MCC: 0.1809
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Fold 4 - Classification Report:
              precision    recall  f1-score   support

           0     0.5976    0.3118    0.4098      3730
           1     0.6391    0.8530    0.7307      5328

    accuracy                         0.6302      9058
   macro avg     0.6183    0.5824    0.5703      9058
weighted avg     0.6220    0.6302    0.5986      9058

Fold 4 - AUC-ROC: 0.6091
MCC: 0.1975
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Fold 5 - Classification Report:
              precision    recall  f1-score   support

           0     0.5812    0.2965    0.3927      3730
           1     0.6333    0.8504    0.7259      5328

    accuracy                         0.6223      9058
   macro avg     0.6072    0.5735    0.5593      9058
weighted avg     0.6118    0.6223    0.5887      9058

Fold 5 - AUC-ROC: 0.6113
MCC: 0.1775
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Average AUC-ROC across folds:
0.6108 ± 0.0022

TAB NET¶

In [114]:
# Configuration for the maintenance-needs TabNet run.
# Fix: TARGET_COL previously said 'Failure_Event', a leftover from the first model.
TARGET_COL = 'needs_maintenance'
CATEGORICAL_COLS = ['Asset_Type', 'Location']
RANDOM_STATE = 42
N_SPLITS = 5
EPOCHS = 100
BATCH_SIZE = 64

# Identify categorical feature indices, cardinalities, and embedding sizes
cat_idxs = [X_train_balanced.columns.get_loc(col) for col in CATEGORICAL_COLS]
cat_dims = [int(df[col].nunique()) for col in CATEGORICAL_COLS]
cat_emb_dim = [min(50, (dim + 1) // 2) for dim in cat_dims]

# Cross-validation setup
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)

fold = 1
auc_scores = []

# Class weights computed once on the resampled training target
# (compute_class_weight is already imported at the top of the notebook)
classes = np.unique(y_train_balanced)
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train_balanced)
weights = dict(zip(classes, class_weights))

for train_idx, val_idx in skf.split(X_train_balanced, y_train_balanced):
    print(f"\n==== Fold {fold} ====")
    
    TAB_X_train, TAB_X_val = X_train_balanced.values[train_idx], X_train_balanced.values[val_idx]
    TAB_y_train, TAB_y_val = y_train_balanced.values[train_idx], y_train_balanced.values[val_idx]

    # Fresh model per fold so no trained state leaks between folds
    clf_M = TabNetClassifier(
        n_d=32, n_a=32, n_steps=5, gamma=1.5,
        cat_idxs=cat_idxs,
        cat_dims=cat_dims,
        cat_emb_dim=cat_emb_dim,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=1e-2),
        scheduler_params={"step_size":10, "gamma":0.9},
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        mask_type='entmax',
        seed=RANDOM_STATE,
        verbose=1
    )

    clf_M.fit(
        X_train=TAB_X_train, y_train=TAB_y_train,
        eval_set=[(TAB_X_val, TAB_y_val)],
        eval_name=['val'],
        eval_metric=['auc'],
        max_epochs=EPOCHS,
        patience=20,
        batch_size=BATCH_SIZE,
        virtual_batch_size=128,
        weights=weights
    )

    # Evaluation. Fix: the fold score is now fold_auc instead of `auc`, so
    # sklearn's auc() function is no longer shadowed and the per-iteration
    # re-import that worked around the shadowing is removed.
    y_pred_proba = clf_M.predict_proba(TAB_X_val)[:, 1]
    y_pred = clf_M.predict(TAB_X_val)
    fold_auc = roc_auc_score(TAB_y_val, y_pred_proba)
    auc_scores.append(fold_auc)
    
    print(f"Fold {fold} AUC: {fold_auc:.4f}")
    print(classification_report(TAB_y_val, y_pred, digits=4))
    mcc = matthews_corrcoef(TAB_y_val, y_pred)
    print(f"MCC: {mcc:.4f}")
    # Confusion matrix
    cm = confusion_matrix(TAB_y_val, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap='Blues')
    plt.title("Confusion Matrix")
    plt.show()

    # ROC curve (roc_curve/auc come from the notebook's top-level imports)
    fpr, tpr, thresholds = roc_curve(TAB_y_val, y_pred_proba)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.grid()
    plt.tight_layout()
    plt.show()

    # Histogram of predicted probabilities by true class
    plt.figure(figsize=(6, 4))
    plt.hist(y_pred_proba[TAB_y_val == 0], bins=30, alpha=0.6, label='Class 0 (Does not require)', color='skyblue')
    plt.hist(y_pred_proba[TAB_y_val == 1], bins=30, alpha=0.6, label='Class 1 (Needs_Maintainence)', color='salmon')
    plt.xlabel('Predicted Probability')
    plt.ylabel('Count')
    plt.title('Histogram of Predicted Probabilities by Class')
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()
    
    fold += 1

# Final average AUC; persist the last fold's model
print(f"\n==== Cross-Validation Complete ====")
print(f"Mean AUC: {np.mean(auc_scores):.4f} | Std AUC: {np.std(auc_scores):.4f}")
clf_M.save_model("Needs_Maintainence_tabnet_model")
==== Fold 1 ====
epoch 0  | loss: 0.73862 | val_auc: 0.51499 |  0:00:10s
epoch 1  | loss: 0.69416 | val_auc: 0.53331 |  0:00:20s
epoch 2  | loss: 0.69339 | val_auc: 0.53815 |  0:00:31s
epoch 3  | loss: 0.69292 | val_auc: 0.53089 |  0:00:41s
epoch 4  | loss: 0.69277 | val_auc: 0.54502 |  0:00:50s
epoch 5  | loss: 0.69199 | val_auc: 0.53341 |  0:01:00s
epoch 6  | loss: 0.69283 | val_auc: 0.53721 |  0:01:09s
epoch 7  | loss: 0.69194 | val_auc: 0.54032 |  0:01:21s
epoch 8  | loss: 0.69202 | val_auc: 0.5424  |  0:01:31s
epoch 9  | loss: 0.69166 | val_auc: 0.54612 |  0:01:51s
epoch 10 | loss: 0.69144 | val_auc: 0.53952 |  0:02:01s
epoch 11 | loss: 0.6909  | val_auc: 0.55133 |  0:02:12s
epoch 12 | loss: 0.69087 | val_auc: 0.5454  |  0:02:25s
epoch 13 | loss: 0.69113 | val_auc: 0.53313 |  0:02:36s
epoch 14 | loss: 0.69073 | val_auc: 0.54081 |  0:02:46s
epoch 15 | loss: 0.69203 | val_auc: 0.54166 |  0:02:56s
epoch 16 | loss: 0.69194 | val_auc: 0.536   |  0:03:06s
epoch 17 | loss: 0.69139 | val_auc: 0.54411 |  0:03:16s
epoch 18 | loss: 0.69065 | val_auc: 0.54168 |  0:03:26s
epoch 19 | loss: 0.69148 | val_auc: 0.54292 |  0:03:36s
epoch 20 | loss: 0.69137 | val_auc: 0.54744 |  0:03:46s
epoch 21 | loss: 0.69066 | val_auc: 0.54813 |  0:03:55s
epoch 22 | loss: 0.69123 | val_auc: 0.54103 |  0:04:05s
epoch 23 | loss: 0.6911  | val_auc: 0.54438 |  0:04:15s
epoch 24 | loss: 0.69162 | val_auc: 0.54191 |  0:04:24s
epoch 25 | loss: 0.69145 | val_auc: 0.54712 |  0:04:34s
epoch 26 | loss: 0.6899  | val_auc: 0.54965 |  0:04:43s
epoch 27 | loss: 0.69123 | val_auc: 0.54522 |  0:04:53s
epoch 28 | loss: 0.69039 | val_auc: 0.5529  |  0:05:02s
epoch 29 | loss: 0.69031 | val_auc: 0.54752 |  0:05:11s
epoch 30 | loss: 0.69056 | val_auc: 0.54231 |  0:05:20s
epoch 31 | loss: 0.69079 | val_auc: 0.54904 |  0:05:30s
epoch 32 | loss: 0.69075 | val_auc: 0.54459 |  0:05:39s
epoch 33 | loss: 0.69099 | val_auc: 0.53882 |  0:05:50s
epoch 34 | loss: 0.6911  | val_auc: 0.54586 |  0:06:00s
epoch 35 | loss: 0.69131 | val_auc: 0.52696 |  0:06:11s
epoch 36 | loss: 0.691   | val_auc: 0.5423  |  0:06:21s
epoch 37 | loss: 0.69104 | val_auc: 0.54299 |  0:06:30s
epoch 38 | loss: 0.69103 | val_auc: 0.54231 |  0:06:40s
epoch 39 | loss: 0.69053 | val_auc: 0.53998 |  0:06:50s
epoch 40 | loss: 0.69161 | val_auc: 0.5415  |  0:06:59s
epoch 41 | loss: 0.69078 | val_auc: 0.54198 |  0:07:09s
epoch 42 | loss: 0.69055 | val_auc: 0.54156 |  0:07:19s
epoch 43 | loss: 0.6903  | val_auc: 0.55016 |  0:07:28s
epoch 44 | loss: 0.69023 | val_auc: 0.55124 |  0:07:37s
epoch 45 | loss: 0.69046 | val_auc: 0.55141 |  0:07:47s
epoch 46 | loss: 0.69092 | val_auc: 0.54714 |  0:07:56s
epoch 47 | loss: 0.69081 | val_auc: 0.54517 |  0:08:05s
epoch 48 | loss: 0.69081 | val_auc: 0.54334 |  0:08:15s

Early stopping occurred at epoch 48 with best_epoch = 28 and best_val_auc = 0.5529
Fold 1 AUC: 0.5529
              precision    recall  f1-score   support

           0     0.4578    0.4670    0.4624      3730
           1     0.6216    0.6129    0.6172      5329

    accuracy                         0.5528      9059
   macro avg     0.5397    0.5399    0.5398      9059
weighted avg     0.5542    0.5528    0.5535      9059

MCC: 0.0797
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
==== Fold 2 ====
epoch 0  | loss: 0.74367 | val_auc: 0.51752 |  0:00:09s
epoch 1  | loss: 0.69457 | val_auc: 0.52189 |  0:00:18s
epoch 2  | loss: 0.69358 | val_auc: 0.5221  |  0:00:27s
epoch 3  | loss: 0.69297 | val_auc: 0.52514 |  0:00:38s
epoch 4  | loss: 0.69195 | val_auc: 0.5362  |  0:00:48s
epoch 5  | loss: 0.69141 | val_auc: 0.53066 |  0:00:58s
epoch 6  | loss: 0.69216 | val_auc: 0.54248 |  0:01:08s
epoch 7  | loss: 0.6919  | val_auc: 0.54002 |  0:01:18s
epoch 8  | loss: 0.69249 | val_auc: 0.53261 |  0:01:27s
epoch 9  | loss: 0.69238 | val_auc: 0.52982 |  0:01:37s
epoch 10 | loss: 0.69158 | val_auc: 0.53861 |  0:01:46s
epoch 11 | loss: 0.6917  | val_auc: 0.534   |  0:01:56s
epoch 12 | loss: 0.69119 | val_auc: 0.53371 |  0:02:06s
epoch 13 | loss: 0.6917  | val_auc: 0.53483 |  0:02:15s
epoch 14 | loss: 0.69118 | val_auc: 0.53192 |  0:02:25s
epoch 15 | loss: 0.69228 | val_auc: 0.53583 |  0:02:34s
epoch 16 | loss: 0.69103 | val_auc: 0.52722 |  0:02:44s
epoch 17 | loss: 0.69155 | val_auc: 0.52847 |  0:02:53s
epoch 18 | loss: 0.69104 | val_auc: 0.53745 |  0:03:03s
epoch 19 | loss: 0.69131 | val_auc: 0.5522  |  0:03:12s
epoch 20 | loss: 0.69    | val_auc: 0.55333 |  0:03:22s
epoch 21 | loss: 0.6905  | val_auc: 0.53932 |  0:03:32s
epoch 22 | loss: 0.69032 | val_auc: 0.5469  |  0:03:42s
epoch 23 | loss: 0.69092 | val_auc: 0.54167 |  0:03:51s
epoch 24 | loss: 0.69007 | val_auc: 0.54281 |  0:04:01s
epoch 25 | loss: 0.69042 | val_auc: 0.55122 |  0:04:11s
epoch 26 | loss: 0.6901  | val_auc: 0.54927 |  0:04:20s
epoch 27 | loss: 0.69133 | val_auc: 0.54574 |  0:04:29s
epoch 28 | loss: 0.69108 | val_auc: 0.54386 |  0:04:39s
epoch 29 | loss: 0.69033 | val_auc: 0.54586 |  0:04:48s
epoch 30 | loss: 0.68972 | val_auc: 0.54617 |  0:04:58s
epoch 31 | loss: 0.69002 | val_auc: 0.54678 |  0:05:07s
epoch 32 | loss: 0.68884 | val_auc: 0.54365 |  0:05:17s
epoch 33 | loss: 0.69046 | val_auc: 0.5424  |  0:05:26s
epoch 34 | loss: 0.69054 | val_auc: 0.54204 |  0:05:35s
epoch 35 | loss: 0.68904 | val_auc: 0.53871 |  0:05:45s
epoch 36 | loss: 0.68944 | val_auc: 0.53307 |  0:05:54s
epoch 37 | loss: 0.69127 | val_auc: 0.53761 |  0:06:04s
epoch 38 | loss: 0.69016 | val_auc: 0.53836 |  0:06:14s
epoch 39 | loss: 0.68898 | val_auc: 0.543   |  0:06:24s
epoch 40 | loss: 0.69073 | val_auc: 0.52934 |  0:06:34s

Early stopping occurred at epoch 40 with best_epoch = 20 and best_val_auc = 0.55333
Fold 2 AUC: 0.5533
              precision    recall  f1-score   support

           0     0.4479    0.5329    0.4867      3729
           1     0.6231    0.5404    0.5788      5329

    accuracy                         0.5373      9058
   macro avg     0.5355    0.5366    0.5328      9058
weighted avg     0.5510    0.5373    0.5409      9058

MCC: 0.0722
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
==== Fold 3 ====
epoch 0  | loss: 0.74304 | val_auc: 0.52747 |  0:00:11s
epoch 1  | loss: 0.69485 | val_auc: 0.51745 |  0:00:22s
epoch 2  | loss: 0.69377 | val_auc: 0.53443 |  0:00:34s
epoch 3  | loss: 0.69338 | val_auc: 0.52869 |  0:00:47s
epoch 4  | loss: 0.69234 | val_auc: 0.53622 |  0:00:59s
epoch 5  | loss: 0.69235 | val_auc: 0.5275  |  0:01:12s
epoch 6  | loss: 0.69239 | val_auc: 0.52821 |  0:01:25s
epoch 7  | loss: 0.69255 | val_auc: 0.53736 |  0:01:36s
epoch 8  | loss: 0.69244 | val_auc: 0.53165 |  0:01:46s
epoch 9  | loss: 0.6927  | val_auc: 0.53163 |  0:01:57s
epoch 10 | loss: 0.69191 | val_auc: 0.53879 |  0:02:07s
epoch 11 | loss: 0.69221 | val_auc: 0.54915 |  0:02:17s
epoch 12 | loss: 0.69166 | val_auc: 0.54173 |  0:02:27s
epoch 13 | loss: 0.69172 | val_auc: 0.54136 |  0:02:38s
epoch 14 | loss: 0.69246 | val_auc: 0.53891 |  0:02:48s
epoch 15 | loss: 0.69177 | val_auc: 0.54401 |  0:02:58s
epoch 16 | loss: 0.69112 | val_auc: 0.54718 |  0:03:08s
epoch 17 | loss: 0.69115 | val_auc: 0.54855 |  0:03:18s
epoch 18 | loss: 0.69055 | val_auc: 0.5427  |  0:03:29s
epoch 19 | loss: 0.69139 | val_auc: 0.54593 |  0:03:39s
epoch 20 | loss: 0.69151 | val_auc: 0.54426 |  0:03:49s
epoch 21 | loss: 0.69124 | val_auc: 0.5407  |  0:03:59s
epoch 22 | loss: 0.69182 | val_auc: 0.5373  |  0:04:09s
epoch 23 | loss: 0.69217 | val_auc: 0.53434 |  0:04:19s
epoch 24 | loss: 0.69191 | val_auc: 0.54526 |  0:04:28s
epoch 25 | loss: 0.69145 | val_auc: 0.53809 |  0:04:38s
epoch 26 | loss: 0.69112 | val_auc: 0.53639 |  0:04:47s
epoch 27 | loss: 0.69169 | val_auc: 0.53389 |  0:04:57s
epoch 28 | loss: 0.6912  | val_auc: 0.53249 |  0:05:07s
epoch 29 | loss: 0.69153 | val_auc: 0.53651 |  0:05:16s
epoch 30 | loss: 0.69081 | val_auc: 0.52509 |  0:05:25s
epoch 31 | loss: 0.69062 | val_auc: 0.52946 |  0:05:35s

Early stopping occurred at epoch 31 with best_epoch = 11 and best_val_auc = 0.54915
Fold 3 AUC: 0.5491
              precision    recall  f1-score   support

           0     0.4348    0.6461    0.5198      3730
           1     0.6245    0.4120    0.4964      5328

    accuracy                         0.5084      9058
   macro avg     0.5296    0.5290    0.5081      9058
weighted avg     0.5464    0.5084    0.5061      9058

MCC: 0.0587
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
==== Fold 4 ====
epoch 0  | loss: 0.73705 | val_auc: 0.51952 |  0:00:10s
epoch 1  | loss: 0.69329 | val_auc: 0.53089 |  0:00:20s
epoch 2  | loss: 0.69252 | val_auc: 0.54646 |  0:00:31s
epoch 3  | loss: 0.69228 | val_auc: 0.54065 |  0:00:41s
epoch 4  | loss: 0.69184 | val_auc: 0.534   |  0:00:51s
epoch 5  | loss: 0.69193 | val_auc: 0.54657 |  0:01:01s
epoch 6  | loss: 0.69151 | val_auc: 0.55287 |  0:01:10s
epoch 7  | loss: 0.69107 | val_auc: 0.54494 |  0:01:20s
epoch 8  | loss: 0.69122 | val_auc: 0.52761 |  0:01:29s
epoch 9  | loss: 0.69107 | val_auc: 0.55175 |  0:01:38s
epoch 10 | loss: 0.68997 | val_auc: 0.54981 |  0:01:48s
epoch 11 | loss: 0.69011 | val_auc: 0.54628 |  0:01:58s
epoch 12 | loss: 0.69029 | val_auc: 0.54834 |  0:02:07s
epoch 13 | loss: 0.68975 | val_auc: 0.54966 |  0:02:15s
epoch 14 | loss: 0.68997 | val_auc: 0.54721 |  0:02:25s
epoch 15 | loss: 0.69032 | val_auc: 0.55026 |  0:02:36s
epoch 16 | loss: 0.69028 | val_auc: 0.55701 |  0:02:47s
epoch 17 | loss: 0.69031 | val_auc: 0.54663 |  0:02:57s
epoch 18 | loss: 0.69004 | val_auc: 0.55328 |  0:03:08s
epoch 19 | loss: 0.68989 | val_auc: 0.55197 |  0:03:19s
epoch 20 | loss: 0.6898  | val_auc: 0.54563 |  0:03:30s
epoch 21 | loss: 0.68916 | val_auc: 0.55396 |  0:03:40s
epoch 22 | loss: 0.69026 | val_auc: 0.55576 |  0:03:51s
epoch 23 | loss: 0.68898 | val_auc: 0.55987 |  0:04:02s
epoch 24 | loss: 0.68901 | val_auc: 0.5621  |  0:04:14s
epoch 25 | loss: 0.68924 | val_auc: 0.56154 |  0:04:25s
epoch 26 | loss: 0.68936 | val_auc: 0.56151 |  0:04:34s
epoch 27 | loss: 0.68945 | val_auc: 0.55443 |  0:04:44s
epoch 28 | loss: 0.68945 | val_auc: 0.55818 |  0:04:53s
epoch 29 | loss: 0.68815 | val_auc: 0.55186 |  0:05:02s
epoch 30 | loss: 0.68838 | val_auc: 0.56158 |  0:05:11s
epoch 31 | loss: 0.68856 | val_auc: 0.56195 |  0:05:21s
epoch 32 | loss: 0.68918 | val_auc: 0.56096 |  0:05:30s
epoch 33 | loss: 0.68899 | val_auc: 0.56063 |  0:05:40s
epoch 34 | loss: 0.68933 | val_auc: 0.55781 |  0:05:49s
epoch 35 | loss: 0.68948 | val_auc: 0.55896 |  0:06:00s
epoch 36 | loss: 0.68867 | val_auc: 0.56113 |  0:06:12s
epoch 37 | loss: 0.68971 | val_auc: 0.56398 |  0:06:21s
epoch 38 | loss: 0.68879 | val_auc: 0.55588 |  0:06:31s
epoch 39 | loss: 0.6882  | val_auc: 0.5618  |  0:06:40s
epoch 40 | loss: 0.68834 | val_auc: 0.55574 |  0:06:49s
epoch 41 | loss: 0.68724 | val_auc: 0.56463 |  0:06:58s
epoch 42 | loss: 0.68692 | val_auc: 0.56536 |  0:07:07s
epoch 43 | loss: 0.68593 | val_auc: 0.56375 |  0:07:17s
epoch 44 | loss: 0.68509 | val_auc: 0.5708  |  0:07:27s
epoch 45 | loss: 0.68516 | val_auc: 0.56854 |  0:07:37s
epoch 46 | loss: 0.68481 | val_auc: 0.56642 |  0:07:49s
epoch 47 | loss: 0.68374 | val_auc: 0.57037 |  0:08:02s
epoch 48 | loss: 0.6832  | val_auc: 0.5712  |  0:08:13s
epoch 49 | loss: 0.68235 | val_auc: 0.56695 |  0:08:24s
epoch 50 | loss: 0.68079 | val_auc: 0.5731  |  0:08:35s
epoch 51 | loss: 0.67951 | val_auc: 0.56181 |  0:08:45s
epoch 52 | loss: 0.67734 | val_auc: 0.56499 |  0:08:56s
epoch 53 | loss: 0.6731  | val_auc: 0.58119 |  0:09:06s
epoch 54 | loss: 0.66857 | val_auc: 0.58393 |  0:09:17s
epoch 55 | loss: 0.66559 | val_auc: 0.58995 |  0:09:28s
epoch 56 | loss: 0.66147 | val_auc: 0.57975 |  0:09:38s
epoch 57 | loss: 0.65804 | val_auc: 0.59748 |  0:09:48s
epoch 58 | loss: 0.65249 | val_auc: 0.59459 |  0:09:58s
epoch 59 | loss: 0.64925 | val_auc: 0.58605 |  0:10:10s
epoch 60 | loss: 0.64754 | val_auc: 0.59586 |  0:10:21s
epoch 61 | loss: 0.6474  | val_auc: 0.58105 |  0:10:33s
epoch 62 | loss: 0.6467  | val_auc: 0.60681 |  0:10:43s
epoch 63 | loss: 0.64527 | val_auc: 0.59729 |  0:10:54s
epoch 64 | loss: 0.64216 | val_auc: 0.58857 |  0:11:04s
epoch 65 | loss: 0.64467 | val_auc: 0.57536 |  0:11:15s
epoch 66 | loss: 0.64411 | val_auc: 0.59619 |  0:11:25s
epoch 67 | loss: 0.64353 | val_auc: 0.57535 |  0:11:35s
epoch 68 | loss: 0.64291 | val_auc: 0.60058 |  0:11:45s
epoch 69 | loss: 0.6421  | val_auc: 0.59725 |  0:11:55s
epoch 70 | loss: 0.64248 | val_auc: 0.59396 |  0:12:06s
epoch 71 | loss: 0.64041 | val_auc: 0.58804 |  0:12:16s
epoch 72 | loss: 0.6408  | val_auc: 0.60002 |  0:12:26s
epoch 73 | loss: 0.64118 | val_auc: 0.57349 |  0:12:36s
epoch 74 | loss: 0.63999 | val_auc: 0.59896 |  0:12:46s
epoch 75 | loss: 0.64145 | val_auc: 0.58056 |  0:12:55s
epoch 76 | loss: 0.63728 | val_auc: 0.60125 |  0:13:04s
epoch 77 | loss: 0.63837 | val_auc: 0.60125 |  0:13:14s
epoch 78 | loss: 0.63685 | val_auc: 0.5927  |  0:13:24s
epoch 79 | loss: 0.63236 | val_auc: 0.61108 |  0:13:35s
epoch 80 | loss: 0.62904 | val_auc: 0.58998 |  0:13:45s
epoch 81 | loss: 0.62744 | val_auc: 0.60016 |  0:13:55s
epoch 82 | loss: 0.62566 | val_auc: 0.60276 |  0:14:05s
epoch 83 | loss: 0.62634 | val_auc: 0.61056 |  0:14:15s
epoch 84 | loss: 0.62373 | val_auc: 0.58483 |  0:14:26s
epoch 85 | loss: 0.62367 | val_auc: 0.60493 |  0:14:37s
epoch 86 | loss: 0.62276 | val_auc: 0.58662 |  0:14:48s
epoch 87 | loss: 0.62387 | val_auc: 0.60769 |  0:14:58s
epoch 88 | loss: 0.62379 | val_auc: 0.60117 |  0:15:07s
epoch 89 | loss: 0.62122 | val_auc: 0.6004  |  0:15:17s
epoch 90 | loss: 0.62167 | val_auc: 0.6048  |  0:15:27s
epoch 91 | loss: 0.61774 | val_auc: 0.60171 |  0:15:36s
epoch 92 | loss: 0.62124 | val_auc: 0.59599 |  0:15:45s
epoch 93 | loss: 0.62162 | val_auc: 0.58103 |  0:15:55s
epoch 94 | loss: 0.61876 | val_auc: 0.61844 |  0:16:05s
epoch 95 | loss: 0.62024 | val_auc: 0.56773 |  0:16:15s
epoch 96 | loss: 0.61768 | val_auc: 0.59562 |  0:16:25s
epoch 97 | loss: 0.6174  | val_auc: 0.61055 |  0:16:35s
epoch 98 | loss: 0.61906 | val_auc: 0.60331 |  0:16:46s
epoch 99 | loss: 0.6201  | val_auc: 0.6006  |  0:16:57s
Stop training because you reached max_epochs = 100 with best_epoch = 94 and best_val_auc = 0.61844
Fold 4 AUC: 0.6184
              precision    recall  f1-score   support

           0     0.4866    0.5244    0.5048      3730
           1     0.6479    0.6126    0.6298      5328

    accuracy                         0.5763      9058
   macro avg     0.5672    0.5685    0.5673      9058
weighted avg     0.5815    0.5763    0.5783      9058

MCC: 0.1357
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
==== Fold 5 ====
epoch 0  | loss: 0.7381  | val_auc: 0.53926 |  0:00:09s
epoch 1  | loss: 0.69417 | val_auc: 0.54001 |  0:00:18s
epoch 2  | loss: 0.69334 | val_auc: 0.51119 |  0:00:28s
epoch 3  | loss: 0.69369 | val_auc: 0.52295 |  0:00:37s
epoch 4  | loss: 0.69285 | val_auc: 0.52012 |  0:00:46s
epoch 5  | loss: 0.69289 | val_auc: 0.53658 |  0:00:55s
epoch 6  | loss: 0.69219 | val_auc: 0.5421  |  0:01:04s
epoch 7  | loss: 0.69106 | val_auc: 0.54826 |  0:01:14s
epoch 8  | loss: 0.69029 | val_auc: 0.54747 |  0:01:24s
epoch 9  | loss: 0.69032 | val_auc: 0.54592 |  0:01:39s
epoch 10 | loss: 0.68874 | val_auc: 0.55297 |  0:01:51s
epoch 11 | loss: 0.69012 | val_auc: 0.54611 |  0:02:02s
epoch 12 | loss: 0.68962 | val_auc: 0.55438 |  0:02:14s
epoch 13 | loss: 0.69062 | val_auc: 0.54867 |  0:02:25s
epoch 14 | loss: 0.68956 | val_auc: 0.5475  |  0:02:39s
epoch 15 | loss: 0.69103 | val_auc: 0.55061 |  0:02:51s
epoch 16 | loss: 0.6898  | val_auc: 0.55431 |  0:03:03s
epoch 17 | loss: 0.68735 | val_auc: 0.55323 |  0:03:15s
epoch 18 | loss: 0.68786 | val_auc: 0.55427 |  0:03:29s
epoch 19 | loss: 0.6885  | val_auc: 0.55725 |  0:03:40s
epoch 20 | loss: 0.68761 | val_auc: 0.56605 |  0:03:52s
epoch 21 | loss: 0.68725 | val_auc: 0.56148 |  0:04:04s
epoch 22 | loss: 0.68766 | val_auc: 0.55726 |  0:04:14s
epoch 23 | loss: 0.68739 | val_auc: 0.55932 |  0:04:24s
epoch 24 | loss: 0.68669 | val_auc: 0.56105 |  0:04:34s
epoch 25 | loss: 0.68812 | val_auc: 0.55484 |  0:04:43s
epoch 26 | loss: 0.68616 | val_auc: 0.56329 |  0:04:52s
epoch 27 | loss: 0.68647 | val_auc: 0.56096 |  0:05:01s
epoch 28 | loss: 0.6869  | val_auc: 0.56062 |  0:05:10s
epoch 29 | loss: 0.6858  | val_auc: 0.56036 |  0:05:18s
epoch 30 | loss: 0.68557 | val_auc: 0.56423 |  0:05:27s
epoch 31 | loss: 0.68521 | val_auc: 0.55339 |  0:05:35s
epoch 32 | loss: 0.68604 | val_auc: 0.56167 |  0:05:43s
epoch 33 | loss: 0.68492 | val_auc: 0.55643 |  0:05:51s
epoch 34 | loss: 0.6846  | val_auc: 0.56229 |  0:05:59s
epoch 35 | loss: 0.68464 | val_auc: 0.55972 |  0:06:07s
epoch 36 | loss: 0.68277 | val_auc: 0.55484 |  0:06:15s
epoch 37 | loss: 0.68264 | val_auc: 0.56243 |  0:06:23s
epoch 38 | loss: 0.68409 | val_auc: 0.56235 |  0:06:31s
epoch 39 | loss: 0.68236 | val_auc: 0.55643 |  0:06:38s
epoch 40 | loss: 0.68221 | val_auc: 0.56639 |  0:06:45s
epoch 41 | loss: 0.68162 | val_auc: 0.56583 |  0:06:53s
epoch 42 | loss: 0.67926 | val_auc: 0.56331 |  0:07:00s
epoch 43 | loss: 0.6806  | val_auc: 0.56547 |  0:07:07s
epoch 44 | loss: 0.68012 | val_auc: 0.56219 |  0:07:15s
epoch 45 | loss: 0.67786 | val_auc: 0.56578 |  0:07:22s
epoch 46 | loss: 0.67827 | val_auc: 0.56375 |  0:07:31s
epoch 47 | loss: 0.67381 | val_auc: 0.57245 |  0:07:40s
epoch 48 | loss: 0.67332 | val_auc: 0.57125 |  0:07:49s
epoch 49 | loss: 0.6708  | val_auc: 0.57653 |  0:07:57s
epoch 50 | loss: 0.6678  | val_auc: 0.57729 |  0:08:06s
epoch 51 | loss: 0.66534 | val_auc: 0.58197 |  0:08:14s
epoch 52 | loss: 0.66391 | val_auc: 0.58182 |  0:08:24s
epoch 53 | loss: 0.66229 | val_auc: 0.58137 |  0:08:33s
epoch 54 | loss: 0.65908 | val_auc: 0.59146 |  0:08:42s
epoch 55 | loss: 0.65998 | val_auc: 0.59352 |  0:08:51s
epoch 56 | loss: 0.65774 | val_auc: 0.5932  |  0:09:00s
epoch 57 | loss: 0.65654 | val_auc: 0.59505 |  0:09:09s
epoch 58 | loss: 0.65287 | val_auc: 0.59746 |  0:09:18s
epoch 59 | loss: 0.64709 | val_auc: 0.60005 |  0:09:26s
epoch 60 | loss: 0.64252 | val_auc: 0.59609 |  0:09:35s
epoch 61 | loss: 0.63922 | val_auc: 0.60742 |  0:09:45s
epoch 62 | loss: 0.63946 | val_auc: 0.6078  |  0:09:54s
epoch 63 | loss: 0.63713 | val_auc: 0.60222 |  0:10:03s
epoch 64 | loss: 0.63655 | val_auc: 0.60803 |  0:10:11s
epoch 65 | loss: 0.63522 | val_auc: 0.6067  |  0:10:19s
epoch 66 | loss: 0.63349 | val_auc: 0.60649 |  0:10:27s
epoch 67 | loss: 0.63384 | val_auc: 0.60731 |  0:10:36s
epoch 68 | loss: 0.63134 | val_auc: 0.59559 |  0:10:45s
epoch 69 | loss: 0.63047 | val_auc: 0.58947 |  0:10:53s
epoch 70 | loss: 0.63179 | val_auc: 0.60461 |  0:11:00s
epoch 71 | loss: 0.62901 | val_auc: 0.60928 |  0:11:08s
epoch 72 | loss: 0.62975 | val_auc: 0.60007 |  0:11:17s
epoch 73 | loss: 0.62767 | val_auc: 0.60926 |  0:11:26s
epoch 74 | loss: 0.62735 | val_auc: 0.60034 |  0:11:34s
epoch 75 | loss: 0.62829 | val_auc: 0.60422 |  0:11:43s
epoch 76 | loss: 0.62898 | val_auc: 0.60282 |  0:11:51s
epoch 77 | loss: 0.62904 | val_auc: 0.59976 |  0:12:00s
epoch 78 | loss: 0.6247  | val_auc: 0.59464 |  0:12:09s
epoch 79 | loss: 0.62823 | val_auc: 0.6074  |  0:12:18s
epoch 80 | loss: 0.62719 | val_auc: 0.60294 |  0:12:26s
epoch 81 | loss: 0.62525 | val_auc: 0.61254 |  0:12:35s
epoch 82 | loss: 0.62501 | val_auc: 0.60337 |  0:12:44s
epoch 83 | loss: 0.62337 | val_auc: 0.61158 |  0:12:52s
epoch 84 | loss: 0.62474 | val_auc: 0.60949 |  0:13:00s
epoch 85 | loss: 0.62542 | val_auc: 0.60143 |  0:13:09s
epoch 86 | loss: 0.62309 | val_auc: 0.61275 |  0:13:17s
epoch 87 | loss: 0.62445 | val_auc: 0.61873 |  0:13:26s
epoch 88 | loss: 0.62233 | val_auc: 0.61826 |  0:13:34s
epoch 89 | loss: 0.62163 | val_auc: 0.62097 |  0:13:43s
epoch 90 | loss: 0.6205  | val_auc: 0.62121 |  0:13:51s
epoch 91 | loss: 0.6206  | val_auc: 0.61683 |  0:14:00s
epoch 92 | loss: 0.62121 | val_auc: 0.61265 |  0:14:09s
epoch 93 | loss: 0.62048 | val_auc: 0.60305 |  0:14:19s
epoch 94 | loss: 0.6186  | val_auc: 0.60847 |  0:14:28s
epoch 95 | loss: 0.62331 | val_auc: 0.60684 |  0:14:38s
epoch 96 | loss: 0.61866 | val_auc: 0.60686 |  0:14:48s
epoch 97 | loss: 0.62032 | val_auc: 0.60096 |  0:14:59s
epoch 98 | loss: 0.62132 | val_auc: 0.61791 |  0:15:09s
epoch 99 | loss: 0.62071 | val_auc: 0.60238 |  0:15:18s
Stop training because you reached max_epochs = 100 with best_epoch = 90 and best_val_auc = 0.62121
Fold 5 AUC: 0.6212
              precision    recall  f1-score   support

           0     0.9413    0.2193    0.3557      3730
           1     0.6444    0.9904    0.7808      5328

    accuracy                         0.6729      9058
   macro avg     0.7929    0.6049    0.5683      9058
weighted avg     0.7667    0.6729    0.6058      9058

MCC: 0.3505
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
==== Cross-Validation Complete ====
Mean AUC: 0.5790 | Std AUC: 0.0334
Successfully saved model at Needs_Maintainence_tabnet_model.zip
Out[114]:
'Needs_Maintainence_tabnet_model.zip'

Extra evaluation on unseen (held-out) data¶

In [115]:
# ===== Extra evaluation of the maintenance TabNet model on held-out data =====
# Uses names defined in earlier cells: clf_M (fitted TabNetClassifier),
# TAB_X_val / TAB_y_val (held-out features/labels), X_val (DataFrame with
# the original column names).

# Alias the held-out set used throughout this cell.
TAB_x = TAB_X_val
TAB_y = TAB_y_val

# Positive-class probabilities and hard label predictions.
y_pred_proba = clf_M.predict_proba(TAB_x)[:, 1]
y_pred = clf_M.predict(TAB_x)

# --- Threshold-independent metric: ROC AUC ---
auc_score = roc_auc_score(TAB_y, y_pred_proba)
print(f"\nFinal Validation AUC: {auc_score:.4f}")

# --- Per-class precision / recall / F1 ---
print("Classification Report:")
print(classification_report(TAB_y, y_pred, digits=4))

# --- MCC: single-number summary that stays informative under class imbalance ---
mcc = matthews_corrcoef(TAB_y, y_pred)
print(f"Matthews Correlation Coefficient (MCC): {mcc:.4f}")

# --- Confusion matrix ---
cm = confusion_matrix(TAB_y, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues')
plt.title("Confusion Matrix")
plt.show()

# --- ROC curve ---
# roc_curve is already imported in the notebook's top import cell; no need to
# re-import it here.
fpr, tpr, thresholds = roc_curve(TAB_y, y_pred_proba)
# auc(fpr, tpr) on the full curve is mathematically identical to roc_auc_score
# above, so reuse that value instead of recomputing it.
roc_auc = auc_score

plt.figure(figsize=(6, 4))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.4f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance diagonal
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid()
plt.tight_layout()
plt.show()

# --- Predicted-probability distribution per class ---
# FIX: this cell evaluates the *maintenance* model (clf_M, saved as
# "Needs_Maintainence_tabnet_model.zip"), so the legend labels the classes
# accordingly — the previous "Failure" labels were copied from the
# failure-model evaluation cell.
plt.figure(figsize=(6, 4))
plt.hist(y_pred_proba[TAB_y == 0], bins=30, alpha=0.6, label='Class 0 (No Maintenance Needed)', color='skyblue')
plt.hist(y_pred_proba[TAB_y == 1], bins=30, alpha=0.6, label='Class 1 (Needs Maintenance)', color='salmon')
plt.xlabel('Predicted Probability')
plt.ylabel('Count')
plt.title('Histogram of Predicted Probabilities by Class')
plt.legend()
plt.grid()
plt.tight_layout()
plt.show()

# --- Global feature importance via TabNet's explain() ---
# seaborn is already imported as `sns` at the top of the notebook.
explain_matrix, masks = clf_M.explain(TAB_x)

# NOTE(review): feature names come from X_val while importances come from
# TAB_x (= TAB_X_val); this assumes TAB_X_val preserves X_val's column order —
# confirm against the cell that built TAB_X_val.
feature_importances_df = pd.DataFrame({
    'Feature': X_val.columns,
    'Importance': explain_matrix.sum(axis=0)  # aggregate per-sample importances
}).sort_values(by='Importance', ascending=False)

plt.figure(figsize=(8, 6))
sns.barplot(data=feature_importances_df, x='Importance', y='Feature', palette='viridis')
plt.title("TabNet Feature Importances (Global Explanation)")
plt.tight_layout()
plt.show()
Final Validation AUC: 0.6212
Classification Report:
              precision    recall  f1-score   support

           0     0.9413    0.2193    0.3557      3730
           1     0.6444    0.9904    0.7808      5328

    accuracy                         0.6729      9058
   macro avg     0.7929    0.6049    0.5683      9058
weighted avg     0.7667    0.6729    0.6058      9058

Matthews Correlation Coefficient (MCC): 0.3505
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image